Skip to Content
👋 Hey there! Welcome to NestJS tRPC.

Middleware

Learn how to implement and use middleware for authentication, validation, and more in NestJS tRPC.

Overview

Middleware in NestJS tRPC wraps your procedure handlers: it can inspect or modify the request context before a handler runs, and observe the result (or error) after it completes. Common use cases include:

  • Authentication and authorization
  • Request logging and monitoring
  • Input validation and transformation
  • Rate limiting
  • Caching
  • Error handling

Creating Middleware

Basic Middleware

Create middleware by implementing the MiddlewareFn interface:

import { MiddlewareFn } from '@nexica/nestjs-trpc'
import { TRPCError } from '@trpc/server'
import { RequestContext } from '@/context/app.context'
 
/**
 * Logs every procedure call twice: once on entry (with the request ID taken
 * from the context) and once after the downstream chain resolves, including
 * the elapsed time in milliseconds.
 *
 * The context is forwarded unchanged and the downstream result is returned as-is.
 * If `next` rejects, the completion line is not written and the rejection propagates.
 */
export const LoggingMiddleware: MiddlewareFn<RequestContext> = async ({ ctx, next, path, type }) => {
    console.log(`📝 ${type.toUpperCase()} ${path} - Request ID: ${ctx.requestId}`)

    // Measure only the time spent in the rest of the chain.
    const startedAt = Date.now()
    const outcome = await next({ ctx })
    const elapsedMs = Date.now() - startedAt

    console.log(`✅ ${type.toUpperCase()} ${path} completed in ${elapsedMs}ms`)

    return outcome
}

Authentication Middleware

import { Injectable } from '@nestjs/common'
import { MiddlewareFn } from '@nexica/nestjs-trpc'
import { TRPCError } from '@trpc/server'
import { JwtService } from '@nestjs/jwt'
 
/**
 * Injectable holder for a bearer-token authentication middleware.
 *
 * The middleware is exposed as the `middleware` property so it can close over
 * the injected `JwtService`.
 */
@Injectable()
export class AuthMiddleware {
    constructor(private readonly jwtService: JwtService) {}

    /**
     * Verifies the `Authorization: Bearer <token>` header and forwards a context
     * extended with `userId`, `userEmail` and `userRoles` taken from the token payload.
     *
     * @throws TRPCError `UNAUTHORIZED` when the header is missing, is not a Bearer
     *         header, or when anything inside the verification block throws.
     */
    public readonly middleware: MiddlewareFn<RequestContext> = async ({ ctx, next }) => {
        const authorization = ctx.req.headers.authorization

        // Reject anything that is not a "Bearer <token>" header.
        if (!(authorization && authorization.startsWith('Bearer '))) {
            throw new TRPCError({
                code: 'UNAUTHORIZED',
                message: 'No valid authorization header found',
            })
        }

        // Everything after the "Bearer " prefix is the raw token.
        const token = authorization.slice('Bearer '.length)

        try {
            const claims = await this.jwtService.verifyAsync(token)

            // Hand the enriched context to the rest of the chain.
            return next({
                ctx: {
                    ...ctx,
                    userId: claims.sub,
                    userEmail: claims.email,
                    userRoles: claims.roles || [],
                },
            })
        } catch {
            throw new TRPCError({
                code: 'UNAUTHORIZED',
                message: 'Invalid or expired token',
            })
        }
    }
}
 
/**
 * Factory for use outside dependency injection: builds an AuthMiddleware around
 * the given JwtService and returns its middleware function.
 */
export const createAuthMiddleware = (jwtService: JwtService) => {
    const { middleware } = new AuthMiddleware(jwtService)
    return middleware
}

Role-Based Authorization

/**
 * Builds a middleware that lets a call through only when the caller holds at
 * least one of `requiredRoles`.
 *
 * @param requiredRoles Roles of which the caller must hold at least one.
 * @returns Middleware that forwards the context unchanged on success.
 * @throws TRPCError `UNAUTHORIZED` when `ctx.userId` is falsy,
 *         `FORBIDDEN` when none of the required roles is present in `ctx.userRoles`.
 */
export const createRoleMiddleware = (requiredRoles: string[]) => {
    const middleware: MiddlewareFn<RequestContext> = async ({ ctx, next }) => {
        // userId is expected to have been placed on the context by an earlier auth step.
        if (!ctx.userId) {
            throw new TRPCError({
                code: 'UNAUTHORIZED',
                message: 'Authentication required',
            })
        }

        const grantedRoles = ctx.userRoles || []
        const isGranted = (role: string) => grantedRoles.includes(role)

        // A single matching role is enough.
        if (requiredRoles.some(isGranted)) {
            return next({ ctx })
        }

        throw new TRPCError({
            code: 'FORBIDDEN',
            message: `Access denied. Required roles: ${requiredRoles.join(', ')}`,
        })
    }

    return middleware
}
 
// Usage examples
// Each guard below admits a caller that holds at least one of the listed roles.
/** Admits callers holding the 'admin' role. */
export const AdminMiddleware = createRoleMiddleware(['admin'])
/** Admits callers holding 'admin' or 'moderator'. */
export const ModeratorMiddleware = createRoleMiddleware(['admin', 'moderator'])
/** Admits callers holding 'admin', 'moderator' or 'user'. */
export const UserMiddleware = createRoleMiddleware(['admin', 'moderator', 'user'])

Rate Limiting Middleware

/** Configuration accepted by createRateLimitMiddleware. */
interface RateLimitOptions {
    /** Length of one counting window, in milliseconds. */
    windowMs: number // Time window in milliseconds
    /** Number of requests a single key may make within one window. */
    maxRequests: number // Max requests per window
    /** Optional function deriving the rate-limit key from the request context. */
    keyGenerator?: (ctx: RequestContext) => string
}
 
/**
 * Builds an in-memory, fixed-window rate limiter.
 *
 * Each key gets a counter and the timestamp at which its window ends. Once the
 * window has ended the counter is discarded, so the next request from that key
 * starts a fresh window.
 *
 * State lives in a Map captured by the returned middleware, so it is local to
 * that middleware instance.
 *
 * @param options Window length, per-window request budget and optional key function.
 * @returns Middleware that forwards the context unchanged while the key is within budget.
 * @throws TRPCError `TOO_MANY_REQUESTS` when the key has used up its budget for the current window.
 */
export const createRateLimitMiddleware = (options: RateLimitOptions) => {
    const requests = new Map<string, { count: number; resetTime: number }>()

    const middleware: MiddlewareFn<RequestContext> = async (opts) => {
        const { ctx, next } = opts

        // Generate rate limit key (default: IP address). `connection` may be
        // absent on some request objects, so it is read defensively.
        const key = options.keyGenerator ? options.keyGenerator(ctx) : ctx.req.ip || ctx.req.connection?.remoteAddress || 'unknown'

        const now = Date.now()

        // Discard every window that has already ended. Comparing against `now`
        // (not `now - windowMs`) is what lets an exhausted key recover as soon
        // as its window is over, and keeps the retry hint from going negative.
        for (const [k, v] of requests.entries()) {
            if (v.resetTime <= now) {
                requests.delete(k)
            }
        }

        // Current window for this key, or a fresh one starting now
        const current = requests.get(key) ?? { count: 0, resetTime: now + options.windowMs }

        if (current.count >= options.maxRequests) {
            throw new TRPCError({
                code: 'TOO_MANY_REQUESTS',
                message: `Rate limit exceeded. Try again after ${Math.ceil((current.resetTime - now) / 1000)} seconds.`,
            })
        }

        // Count this request against the window (the window end is not extended)
        requests.set(key, {
            count: current.count + 1,
            resetTime: current.resetTime,
        })

        return next({ ctx })
    }

    return middleware
}
 
// Usage
/**
 * Rate limiter allowing 100 requests per 15-minute window.
 * No keyGenerator is supplied, so the factory's default key derivation applies.
 */
export const ApiRateLimit = createRateLimitMiddleware({
    windowMs: 15 * 60 * 1000, // 15 minutes
    maxRequests: 100, // 100 requests per window
})

Applying Middleware

Router-Level Middleware

Apply middleware to all procedures in a router:

import { Router, Middleware } from '@nexica/nestjs-trpc'
import { AuthMiddleware, LoggingMiddleware } from '@/middleware'
 
/**
 * Example router with router-level middleware: the middleware listed on the
 * class decorator applies to every procedure declared below.
 */
@Router()
@Middleware(LoggingMiddleware, AuthMiddleware) // Runs for every procedure in this router
export class UserRouter {
    /** Looks up the profile of the user identified by `ctx.userId`. */
    @Query()
    async getProfile(@Context() ctx: RequestContext) {
        // ctx.userId is expected to have been set by the auth middleware
        const profile = await this.userService.findById(ctx.userId)
        return profile
    }

    /** Applies `input` to the profile of the user identified by `ctx.userId`. */
    @Mutation()
    async updateProfile(@Input() input: UpdateProfileInput, @Context() ctx: RequestContext) {
        // Logging and auth middleware have both run by the time we get here
        const updated = await this.userService.update(ctx.userId, input)
        return updated
    }
}

Procedure-Level Middleware

Apply middleware to specific procedures:

/**
 * Example router with procedure-level middleware: each procedure declares its
 * own middleware list (or none at all).
 */
@Router()
export class AdminRouter {
    /** Public endpoint: no middleware is attached. */
    @Query()
    async getPublicStats() {
        const stats = await this.statsService.getPublicStats()
        return stats
    }

    /** Guarded by the auth middleware only. */
    @Query()
    @Middleware(AuthMiddleware)
    async getUserStats(@Context() ctx: RequestContext) {
        const stats = await this.statsService.getUserStats(ctx.userId)
        return stats
    }

    /** Guarded by the auth middleware and the admin role middleware. */
    @Mutation()
    @Middleware(AuthMiddleware, AdminMiddleware)
    async deleteUser(@Input() input: { id: string }) {
        const outcome = await this.userService.delete(input.id)
        return outcome
    }
}

Combining Multiple Middleware

Stack multiple middleware in order:

/**
 * Example of stacking router-level and procedure-level middleware.
 */
@Router()
@Middleware(LoggingMiddleware) // Router level: outermost
export class ApiRouter {
    /**
     * Documented execution order for this procedure:
     *   1. LoggingMiddleware (router level)
     *   2. AuthMiddleware    (procedure level)
     *   3. ApiRateLimit      (procedure level)
     *   4. this handler
     */
    @Query()
    @Middleware(AuthMiddleware, ApiRateLimit) // Procedure level: runs after logging
    async sensitiveOperation(@Context() ctx: RequestContext) {
        const outcome = await this.performSensitiveOperation(ctx.userId)
        return outcome
    }
}

Advanced Middleware Patterns

Conditional Middleware

/**
 * Wraps `middleware` so that it only runs when `condition(ctx)` is truthy.
 * Otherwise the call is forwarded straight to `next` with the context unchanged.
 *
 * @param condition Predicate evaluated against the request context on every call.
 * @param middleware Middleware to delegate to when the predicate holds; it receives the original `opts`.
 * @returns The gated middleware.
 */
export const createConditionalMiddleware = (condition: (ctx: RequestContext) => boolean, middleware: MiddlewareFn<RequestContext>) => {
    const conditionalMiddleware: MiddlewareFn<RequestContext> = async (opts) => {
        const { ctx, next } = opts

        // Delegate to the wrapped middleware, or skip it entirely.
        return condition(ctx) ? middleware(opts) : next({ ctx })
    }

    return conditionalMiddleware
}
 
// Usage: Only apply auth to non-public endpoints
// The predicate is true for every request whose path does not start with '/public'.
// NOTE(review): the second argument must be a middleware function; where AuthMiddleware
// is the injectable class, its `.middleware` property is presumably what should be passed — confirm.
export const OptionalAuth = createConditionalMiddleware((ctx) => !ctx.req.path.startsWith('/public'), AuthMiddleware)

Caching Middleware

/** Configuration accepted by createCacheMiddleware. */
interface CacheOptions {
    /** How long a cached result stays valid, in seconds. */
    ttl: number // Time to live in seconds
    /** Derives the cache key from the request context, procedure path and input. */
    keyGenerator: (ctx: RequestContext, path: string, input: any) => string
}
 
/**
 * Builds an in-memory result cache for query procedures.
 *
 * Calls whose `type` is not 'query' bypass the cache. For queries, a stored
 * result is returned while it has not expired; otherwise the downstream chain
 * runs and its result is stored for `options.ttl` seconds, measured from the
 * moment the lookup was made.
 *
 * @param options TTL in seconds and the function producing cache keys.
 * @returns The caching middleware; the cache Map is private to it.
 */
export const createCacheMiddleware = (options: CacheOptions) => {
    const cache = new Map<string, { data: any; expires: number }>()
    const ttlMs = () => options.ttl * 1000

    const middleware: MiddlewareFn<RequestContext> = async (opts) => {
        const { ctx, next, path, input } = opts

        // Anything that is not a query goes straight through
        if (opts.type !== 'query') {
            return next({ ctx })
        }

        const cacheKey = options.keyGenerator(ctx, path, input)
        const now = Date.now()

        // Serve from cache while the entry is still fresh
        const entry = cache.get(cacheKey)
        const isFresh = entry !== undefined && entry.expires > now
        if (isFresh) {
            console.log(`🎯 Cache hit for ${path}`)
            return entry.data
        }

        // Miss or stale: run the procedure and remember its result
        const fresh = await next({ ctx })
        cache.set(cacheKey, { data: fresh, expires: now + ttlMs() })

        console.log(`💾 Cached result for ${path}`)
        return fresh
    }

    return middleware
}
 
// Usage
/**
 * Query cache with a 5-minute TTL. Keys combine the procedure path, the
 * JSON-serialized input and ctx.userId.
 */
export const QueryCache = createCacheMiddleware({
    ttl: 300, // 5 minutes
    keyGenerator: (ctx, path, input) => `${path}:${JSON.stringify(input)}:${ctx.userId}`,
})

Error Handling Middleware

/**
 * Reads `field` from a thrown value without assuming its shape.
 * Returns undefined when the thrown value is not a non-null object.
 */
const readThrownField = (thrown: unknown, field: 'message' | 'stack' | 'code'): unknown =>
    typeof thrown === 'object' && thrown !== null ? (thrown as Record<string, unknown>)[field] : undefined

/**
 * Catches anything thrown by the downstream chain, logs it with request
 * details, and rethrows it as a TRPCError.
 *
 * Mapping:
 *   - `code === 'P2002'` (Prisma unique constraint) -> `CONFLICT`
 *   - `code === 'P2025'` (Prisma record not found)  -> `NOT_FOUND`
 *   - an existing TRPCError                         -> rethrown unchanged
 *   - anything else                                 -> `INTERNAL_SERVER_ERROR`, original value kept as `cause`
 *
 * The caught value is treated as `unknown`, since JavaScript allows throwing any
 * value: a thrown `null`, `undefined` or primitive is mapped like any other
 * unexpected error instead of crashing the handler.
 *
 * @returns The downstream result when nothing is thrown.
 * @throws TRPCError as described above.
 */
export const ErrorHandlingMiddleware: MiddlewareFn<RequestContext> = async (opts) => {
    const { ctx, next, path, type } = opts

    try {
        return await next({ ctx })
    } catch (error: unknown) {
        // Log error details
        console.error(`❌ Error in ${type.toUpperCase()} ${path}:`, {
            error: readThrownField(error, 'message'),
            stack: readThrownField(error, 'stack'),
            userId: ctx.userId,
            requestId: ctx.requestId,
        })

        // Transform known errors
        const code = readThrownField(error, 'code')

        if (code === 'P2002') {
            // Prisma unique constraint
            throw new TRPCError({
                code: 'CONFLICT',
                message: 'Resource already exists',
            })
        }

        if (code === 'P2025') {
            // Prisma record not found
            throw new TRPCError({
                code: 'NOT_FOUND',
                message: 'Resource not found',
            })
        }

        // Re-throw tRPC errors as-is
        if (error instanceof TRPCError) {
            throw error
        }

        // Transform unexpected errors
        throw new TRPCError({
            code: 'INTERNAL_SERVER_ERROR',
            message: 'An unexpected error occurred',
            cause: error,
        })
    }
}

Testing Middleware

Unit Testing

import { describe, it, expect, vi } from 'vitest'
import { TRPCError } from '@trpc/server'
import { AuthMiddleware } from '@/middleware/auth.middleware'
 
// Unit tests for AuthMiddleware: the JWT service and `next` are replaced by
// vitest mocks, and the middleware function is invoked directly.
describe('AuthMiddleware', () => {
    it('should pass through valid token', async () => {
        // Payload the fake JWT service reports for any token
        const claims = { sub: 'user123', email: 'test@example.com' }
        const jwtStub = { verifyAsync: vi.fn().mockResolvedValue(claims) }
        const nextSpy = vi.fn().mockResolvedValue({ data: 'success' })

        const { middleware } = new AuthMiddleware(jwtStub)

        const ctx = { req: { headers: { authorization: 'Bearer valid-token' } } } as any

        const result = await middleware({ ctx, next: nextSpy, path: 'test', type: 'query' })

        // The downstream call must receive the context enriched with the token claims
        expect(nextSpy).toHaveBeenCalledWith({
            ctx: expect.objectContaining({
                userId: 'user123',
                userEmail: 'test@example.com',
            }),
        })
        expect(result).toEqual({ data: 'success' })
    })

    it('should throw UNAUTHORIZED for missing token', async () => {
        const nextSpy = vi.fn()
        const jwtStub = { verifyAsync: vi.fn() }
        const { middleware } = new AuthMiddleware(jwtStub)

        // No Authorization header at all
        const ctx = { req: { headers: {} } } as any

        const attempt = middleware({ ctx, next: nextSpy, path: 'test', type: 'query' })
        await expect(attempt).rejects.toThrow(TRPCError)

        // The chain must stop at the middleware
        expect(nextSpy).not.toHaveBeenCalled()
    })
})

Integration Testing

import { Test } from '@nestjs/testing'
import { TRPCModule } from '@nexica/nestjs-trpc'
import { AuthMiddleware } from '@/middleware/auth.middleware'
 
// Integration test: boots a Nest application containing the tRPC module and
// sends a real HTTP request through the middleware chain.
describe('Middleware Integration', () => {
    let app: any

    beforeEach(async () => {
        const trpcModule = TRPCModule.forRoot({
            context: RequestContextFactory,
        })

        const moduleRef = await Test.createTestingModule({
            imports: [trpcModule],
            providers: [TestRouter, AuthMiddleware],
        }).compile()

        app = moduleRef.createNestApplication()
        await app.init()
    })

    it('should apply middleware correctly', async () => {
        // Exercise the protected procedure over HTTP with a bearer token
        const { status } = await request(app.getHttpServer())
            .post('/trpc/test.protectedProcedure')
            .set('Authorization', 'Bearer valid-token')
            .send({ input: {} })

        expect(status).toBe(200)
    })
})

Best Practices

1. Keep Middleware Focused

Each middleware should have a single responsibility:

// NOTE: the initializers below are elided placeholders that illustrate naming
// and scope only; this snippet is not runnable as written.
// ✅ Good: Single responsibility
export const AuthMiddleware = /* ... authentication only ... */
export const LoggingMiddleware = /* ... logging only ... */
export const ValidationMiddleware = /* ... validation only ... */
 
// ❌ Avoid: Multiple responsibilities
export const MegaMiddleware = /* ... auth + logging + validation + caching ... */

2. Handle Errors Gracefully

Always handle potential errors in middleware:

/**
 * Template for middleware that guards its own logic with try/catch.
 *
 * Anything thrown by the middleware logic or by the downstream chain is logged
 * and then rethrown unchanged, so callers still observe the original error.
 *
 * @returns The downstream result when nothing is thrown.
 */
export const SafeMiddleware: MiddlewareFn<RequestContext> = async (opts) => {
    // `ctx` and `next` must be taken from `opts`; they are not otherwise in scope.
    const { ctx, next } = opts

    try {
        // Middleware logic here
        return await next({ ctx })
    } catch (error) {
        // Log error and decide whether to throw or handle
        console.error('Middleware error:', error)
        throw error // or handle gracefully
    }
}

3. Use TypeScript for Better Safety

Leverage TypeScript for type-safe middleware:

/**
 * Request context narrowed for authenticated callers: both fields are
 * required here, so code typed against this interface needs no undefined checks.
 */
interface AuthenticatedContext extends RequestContext {
    /** Identifier of the authenticated user. */
    userId: string
    /** Roles held by the authenticated user. */
    userRoles: string[]
}
 
/**
 * Signature sketch only: the body is intentionally elided.
 * NOTE(review): presumably the second type argument of MiddlewareFn is the context
 * type handed to downstream handlers — confirm against the library's typings.
 */
export const AuthMiddleware: MiddlewareFn<RequestContext, AuthenticatedContext> = async (opts) => {
    // Implementation ensures the returned context has userId and userRoles
}

4. Performance Considerations

  • Avoid heavy computations in middleware
  • Use caching for expensive operations
  • Remember that every awaited call in middleware adds latency to each request; run independent async work concurrently where possible

Next Steps

Continue exploring NestJS tRPC:

  • Testing - Learn how to test your middleware
  • Examples - See complete middleware implementations
  • Deployment - Deploy your application with middleware
Last updated on: