diff --git a/package.json b/package.json index 9eccb48..d6c0aab 100644 --- a/package.json +++ b/package.json @@ -48,6 +48,7 @@ "keygrip": "^1.1.0", "mysql2": "^3.10.2", "nanoid": "^5.0.7", + "nestjs-cls": "^4.4.0", "nestjs-otel": "^6.1.1", "nestjs-postal-client": "^0.0.6", "oidc-provider": "^8.5.1", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index d9cf9e3..25e89bd 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -74,12 +74,15 @@ importers: nanoid: specifier: ^5.0.7 version: 5.0.7 + nestjs-cls: + specifier: ^4.4.0 + version: 4.4.0(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(reflect-metadata@0.2.2)(rxjs@7.8.1) nestjs-otel: specifier: ^6.1.1 version: 6.1.1(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1)) nestjs-postal-client: specifier: ^0.0.6 - version: 0.0.6(cawqvsjzhg64rgfdjtb5cpgrqu) + version: 0.0.6(3k6st3bbxynfpw3tlvogds7gku) oidc-provider: specifier: ^8.5.1 version: 8.5.1 @@ -2945,8 +2948,8 @@ packages: neo-async@2.6.2: resolution: {integrity: sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==} - nestjs-cls@4.3.0: - resolution: {integrity: sha512-MVTun6tqCZih8AJXRj8uBuuFyJhQrIA9m9fStiQjbBXUkE3BrlMRvmLzyw8UcneB3xtFFTfwkAh5PYKRulyaOg==} + nestjs-cls@4.4.0: + resolution: {integrity: sha512-qxsptbCo8Cp7xnAxtWv9+pSqOtB2NCr9ekQDH3FhxPAmgOys8F4WEGhuLLQ9iyW4dwqCao0xXatqQyA4anedmQ==} engines: {node: '>=16'} peerDependencies: '@nestjs/common': '> 7.0.0 < 11' @@ -7459,7 +7462,7 @@ snapshots: neo-async@2.6.2: {} - nestjs-cls@4.3.0(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(reflect-metadata@0.2.2)(rxjs@7.8.1): + nestjs-cls@4.4.0(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(reflect-metadata@0.2.2)(rxjs@7.8.1): dependencies: '@nestjs/common': 10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1) '@nestjs/core': 10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1) @@ -7474,7 +7477,7 @@ snapshots: '@opentelemetry/host-metrics': 0.35.3(@opentelemetry/api@1.9.0) response-time: 2.3.2 - nestjs-postal-client@0.0.6(cawqvsjzhg64rgfdjtb5cpgrqu): + nestjs-postal-client@0.0.6(3k6st3bbxynfpw3tlvogds7gku): dependencies: '@nestjs/cache-manager': 2.2.2(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(cache-manager@5.7.2)(rxjs@7.8.1) '@nestjs/common': 10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1) @@ -7483,7 +7486,7 @@ snapshots: axios: 1.7.2 fastify: 4.28.1 kakious-nestjs-http-promise: 0.0.1(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(axios@1.7.2) - nestjs-cls: 4.3.0(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(reflect-metadata@0.2.2)(rxjs@7.8.1) + nestjs-cls: 4.4.0(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(reflect-metadata@0.2.2)(rxjs@7.8.1) nestjs-otel: 6.1.1(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1)) typia: 6.5.1(typescript@5.5.2) transitivePeerDependencies: diff --git a/src/account/account.module.ts b/src/account/account.module.ts new file mode 100644 index 0000000..e69de29 diff --git a/src/app.module.ts b/src/app.module.ts index 69057b2..4a3ae3e 100644 --- a/src/app.module.ts +++ b/src/app.module.ts @@ -1,4 +1,4 @@ -import { Module } from '@nestjs/common'; +import { Module, NestModule } from '@nestjs/common'; import { AppController } from './app.controller'; import { AppService } from './app.service'; import { ConfigModule, ConfigService } from '@nestjs/config'; @@ -13,6 +13,8 @@ import { AuthModule } from './auth/auth.module'; import { UserModule } from './user/user.module'; import { ServeStaticModule } from '@nestjs/serve-static'; import { join } from 'path'; +import { AuthMiddleware } from './auth/middleware/auth.middleware'; +import { ClsModule } from 'nestjs-cls'; @Module({ imports: [ @@ -48,6 +50,10 @@ import { join } from 'path'; serveRoot: '/assets', rootPath: join(__dirname, '..', 'assets'), }), + ClsModule.forRoot({ + global: true, + middleware: { mount: false }, + }), MailModule, RedisModule, UserModule, @@ -56,4 +62,8 @@ import { join } from 'path'; controllers: [AppController], providers: [AppService], }) -export class AppModule {} +export class AppModule implements NestModule { + configure(consumer: import('@nestjs/common').MiddlewareConsumer) { + consumer.apply(AuthMiddleware).forRoutes('*'); + } +} diff --git a/src/app.service.ts b/src/app.service.ts index d12de69..927d7cc 100644 --- a/src/app.service.ts +++ b/src/app.service.ts @@ -1,8 +1,8 @@ -import { Injectable } from '@nestjs/common'; - -@Injectable() -export class AppService { - getHello(): string { - return 'Hello World!'; - } -} +import { Injectable } from '@nestjs/common'; + +@Injectable() +export class AppService { + getHello(): string { + return 'Hello World!'; + } +} diff --git a/src/auth/auth.module.ts b/src/auth/auth.module.ts index 82cf98d..9225706 100644 --- a/src/auth/auth.module.ts +++ b/src/auth/auth.module.ts @@ -21,6 +21,6 @@ import { OidcClientPermission } from '../database/models/oidc_client_permissions ], controllers: [OidcController, AuthController], providers: [ConfigService, OidcService, AuthService], - exports: [], + exports: [OidcService], }) export class AuthModule {} diff --git a/src/auth/controllers/auth.controller.ts b/src/auth/controllers/auth.controller.ts index 85fad8f..813447b 100644 --- a/src/auth/controllers/auth.controller.ts +++ b/src/auth/controllers/auth.controller.ts @@ -6,6 +6,7 @@ import { ForgotPasswordDto } from '../dto/forgotPassword.dto'; import { CreateUserDto } from '../dto/register.dto'; import { LoginUserDto } from '../dto/loginUser.dto'; import { Response } from 'express'; +import { User } from '../decorators/user.decorator'; @Controller('auth') export class AuthController { @@ -68,4 +69,10 @@ export class AuthController { login: 'login', }; } + + @Get('auth-test') + @ApiExcludeEndpoint() + public async getAuthTest(@User() user: any): Promise { + return user; + } } diff --git a/src/auth/decorators/isPublic.decorator.ts b/src/auth/decorators/isPublic.decorator.ts new file mode 100644 index 0000000..e69de29 diff --git a/src/auth/decorators/user.decorator.ts b/src/auth/decorators/user.decorator.ts new file mode 100644 index 0000000..40740aa --- /dev/null +++ b/src/auth/decorators/user.decorator.ts @@ -0,0 +1,16 @@ +import { createParamDecorator } from '@nestjs/common'; +import { ClsServiceManager } from 'nestjs-cls'; + +export const User = createParamDecorator(() => { + const cls = ClsServiceManager.getClsService(); + const authType = cls.get('authType'); + + if (authType !== 'session') { + return null; + } + + const user = cls.get('user'); + // remove the password from the user object + delete user.password; + return user; +}); diff --git a/src/auth/middleware/auth.middleware.ts b/src/auth/middleware/auth.middleware.ts new file mode 100644 index 0000000..197c3d3 --- /dev/null +++ b/src/auth/middleware/auth.middleware.ts @@ -0,0 +1,36 @@ +import { Injectable, NestMiddleware } from '@nestjs/common'; +import { Request, Response, NextFunction } from 'express'; +import { OidcService } from '../oidc/core.service'; +import { ClsService } from 'nestjs-cls'; + +@Injectable() +export class AuthMiddleware implements NestMiddleware { + constructor( + private readonly oidcService: OidcService, + private readonly clsService: ClsService, + ) {} + async use(req: Request, res: Response, next: NextFunction) { + // check if the user has session cookies + if (req.cookies['_session']) { + const { session, user } = await this.validateSession(req, res); + + // set the session and user data in the CLS + this.clsService.set('authType', 'session'); + this.clsService.set('user', user); + this.clsService.set('session', session); + } + next(); + } + + /** + * Validates a session cookie and returns the session and user data + * @param Req The request object + * @param Res The response object + * @returns Promise<{ sessionId: string, user: User }> The session and user data + */ + public async validateSession(req: Request, res: Response): Promise { + // validate the session cookie + const request = await this.oidcService.verifyByRequest(req, res); + return request; + } +} diff --git a/src/auth/oidc/adapter.ts b/src/auth/oidc/adapter.ts index dc6af2e..17a9fed 100644 --- a/src/auth/oidc/adapter.ts +++ b/src/auth/oidc/adapter.ts @@ -44,11 +44,11 @@ const grantable = new Set([ const globalCacheTTL = 1800; const ClientNotSupportedError = new Error('Clients are not supported'); -export const createOidcAdapter: ( - db: DataSource, - redis: RedisService, - baseUrl: string, -) => any = (db, redis, baseUrl) => { +export const createOidcAdapter: (db: DataSource, redis: RedisService, baseUrl: string) => any = ( + db, + redis, + baseUrl, +) => { const oidcClientRepo = db.getRepository(OidcClient); const oidcGrantRepo = db.getRepository(OidcGrant); const oidcSessionRepo = db.getRepository(OidcSession); @@ -70,11 +70,7 @@ export const createOidcAdapter: ( * @param {object} payload Object with all properties intended for storage. * @param {integer} expiresIn Number of seconds intended for this model to be stored. */ - async upsert( - id: string, - payload: AdapterPayload, - expiresIn: number, - ): Promise { + async upsert(id: string, payload: AdapterPayload, expiresIn: number): Promise { switch (this.type) { case TCLIENT: throw ClientNotSupportedError; @@ -123,12 +119,8 @@ export const createOidcAdapter: ( * when encountered. * @param {string} userCode the user_code value associated with a DeviceCode instance */ - async findByUserCode( - userCode: string, - ): Promise { - const id = (await redis.get(SQLAdapter.userCodeKeyFor(userCode))) as - | string - | undefined; + async findByUserCode(userCode: string): Promise { + const id = (await redis.get(SQLAdapter.userCodeKeyFor(userCode))) as string | undefined; if (!id) { return undefined; } @@ -304,13 +296,8 @@ export const createOidcAdapter: ( * @param id Session UUID * @returns Promise */ - private async fetchSession( - id: string, - ): Promise { - const cachedSession = (await redis.get( - this.key(id), - )) as OidcSession | null; - + private async fetchSession(id: string): Promise { + const cachedSession = (await redis.get(this.key(id))) as OidcSession | null; if (cachedSession) { return cachedSession; } @@ -323,10 +310,14 @@ export const createOidcAdapter: ( return undefined; } + // if it exists, convert accountID to string + if (session.accountId) { + session.accountId = session.accountId.toString(); + } + const generatedSession = session.generateResponse(); await redis.set(this.key(id), generatedSession, globalCacheTTL); - return generatedSession; } @@ -368,12 +359,8 @@ export const createOidcAdapter: ( * @param uid Session UUID * @returns Promise */ - private async fetchSessionByUid( - uid: string, - ): Promise { - const cachedSession = (await redis.get(SQLAdapter.uidKey(uid))) as - | string - | null; + private async fetchSessionByUid(uid: string): Promise { + const cachedSession = (await redis.get(SQLAdapter.uidKey(uid))) as string | null; if (cachedSession) { return this.fetchSession(cachedSession); @@ -387,6 +374,11 @@ export const createOidcAdapter: ( return undefined; } + // if it exists, convert accountID to string + if (session.accountId) { + session.accountId = session.accountId.toString(); + } + // cache the uuid reference for the session await redis.set(SQLAdapter.uidKey(uid), session.id, globalCacheTTL); @@ -432,11 +424,7 @@ export const createOidcAdapter: ( client.redirect_uris = []; } - if ( - client.redirect_uris.some((uri) => - uri.includes('system.internalRedirect'), - ) - ) { + if (client.redirect_uris.some((uri) => uri.includes('system.internalRedirect'))) { client.redirect_uris = client.redirect_uris.map((uri) => { if (uri === 'system.internalRedirect.management') { return `${baseUrl}/account/login`; @@ -464,10 +452,7 @@ export const createOidcAdapter: ( * @param expiresIn Refresh Token expiration in seconds * @returns Promise */ - private async upsertRefreshToken( - id: string, - payload: AdapterPayload, - ): Promise { + private async upsertRefreshToken(id: string, payload: AdapterPayload): Promise { await oidcRefreshRepo.upsert( { id, @@ -486,9 +471,7 @@ export const createOidcAdapter: ( * @param id Refresh Token UUID * @returns Promise */ - private async fetchRefreshToken( - id: string, - ): Promise { + private async fetchRefreshToken(id: string): Promise { const refreshToken = await oidcRefreshRepo.findOne({ where: { id }, }); @@ -536,11 +519,7 @@ export const createOidcAdapter: ( * @param expiresIn Model expiration in seconds * @returns Promise */ - async genericUpsert( - id: string, - payload: AdapterPayload, - expiresIn: number, - ): Promise { + async genericUpsert(id: string, payload: AdapterPayload, expiresIn: number): Promise { const multi = redis.multi(); const key = this.key(id); multi.call('JSON.SET', key, '.', JSON.stringify(payload)); @@ -598,11 +577,7 @@ export const createOidcAdapter: ( * @param id Model UUID */ async genericConsume(id: string): Promise { - await redis.jsonSetPath( - this.key(id), - 'consumed', - Math.floor(Date.now() / 1000), - ); + await redis.jsonSetPath(this.key(id), 'consumed', Math.floor(Date.now() / 1000)); } /** diff --git a/src/auth/oidc/core.service.ts b/src/auth/oidc/core.service.ts index ce85d5b..3402a98 100644 --- a/src/auth/oidc/core.service.ts +++ b/src/auth/oidc/core.service.ts @@ -1,5 +1,11 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ -import { Injectable, Logger, OnModuleInit } from '@nestjs/common'; +import { + Injectable, + InternalServerErrorException, + Logger, + NotFoundException, + OnModuleInit, +} from '@nestjs/common'; import { promises as fs } from 'fs'; import type Provider from 'oidc-provider'; import psl from 'psl'; @@ -27,6 +33,8 @@ import generateId from './helper/nanoid.helper'; import { context, trace } from '@opentelemetry/api'; import * as KeyGrip from 'keygrip'; import { getEpochTime } from '../../util/time.util'; +import { VerifiedSessionFromRequest } from './types/session.type'; +import { Request, Response } from 'express'; // This is an async import for the oidc-provider package as it's now only esm and we need to use it in a commonjs environment. async function getProvider(): Promise<{ @@ -68,6 +76,8 @@ export class OidcService implements OnModuleInit { private jwks: any; private cookies: any; private readonly logger = new Logger(OidcService.name); + private sessionNotFoundMessage = 'Session not found'; + private sessionNotFoundError = 'NoSession'; constructor( private readonly userService: UserService, @@ -490,9 +500,18 @@ export class OidcService implements OnModuleInit { httpOnly: true, }, }, + { + name: '_session.legacy', + value: sessionId, + options: { + expires: expire, + sameSite: 'None', + httpOnly: true, + }, + }, { name: '_session.sig', - value: keyGrip.sign(sessionId), + value: keyGrip.sign(pre), options: { expires: expire, sameSite: 'strict', @@ -504,4 +523,76 @@ export class OidcService implements OnModuleInit { this.logger.debug(`Created new session`, { sessionId, accountId }); return { sessionId, cookies, cookiesForms }; } + + /** + * Verify a session by request and response + * @param req The request + * @param res The response + * @returns any + */ + @Span() + async verifyByRequest(req: Request, res: Response): Promise { + try { + const ctx = this.provider.app.createContext(req, res); + const session = await this.provider.Session.get(ctx); + + if (!session.accountId) { + throw new NotFoundException(this.sessionNotFoundMessage, this.sessionNotFoundError); + } + + const user = await this.userService.getUserById(session.accountId); + + if (!user) { + this.logger.error( + 'Account not found while trying to verify Session. Account ID was: ' + session.accountId, + ); + throw new NotFoundException( + 'Account not found while trying to verify Session.', + this.sessionNotFoundError, + ); + } + + return { + session, + user, + }; + } catch (err) { + this.logger.error( + err, + 'There was an error while trying to verify session, purging session cookies', + ); + + //this.clearSessionCookies(res); + throw new InternalServerErrorException('Unknown Error while trying to process session', { + cause: err, + description: 'VerifySessionError', + }); + } + } + + /** + * Remove all cookies related to session and interaction from response + * @param res the response + */ + @Span() + clearCookies(res: Response): void { + res.clearCookie('_interaction'); + res.clearCookie('_interaction.sig'); + res.clearCookie('_session'); + res.clearCookie('_session.sig'); + res.clearCookie('_session.legacy'); + res.clearCookie('_session.legacy.sig'); + } + + /* + * Remove all cookies related to session and interaction from response + * @param res the response + */ + @Span() + clearSessionCookies(res: Response): void { + res.clearCookie('_session'); + res.clearCookie('_session.sig'); + res.clearCookie('_session.legacy'); + res.clearCookie('_session.legacy.sig'); + } } diff --git a/src/auth/oidc/types/session.type.ts b/src/auth/oidc/types/session.type.ts index 206305b..67d6d5d 100644 --- a/src/auth/oidc/types/session.type.ts +++ b/src/auth/oidc/types/session.type.ts @@ -1,6 +1,8 @@ +import { User } from '../../../database/models/user.model'; + export interface VerifiedSessionFromRequest { session: Session; - //USER: User; + user: User; } export interface Session { diff --git a/src/main.ts b/src/main.ts index b47bad0..ef3d71f 100644 --- a/src/main.ts +++ b/src/main.ts @@ -4,6 +4,7 @@ import { join } from 'path'; import { AppModule } from './app.module'; import { ValidationPipe } from '@nestjs/common'; import * as cookieParser from 'cookie-parser'; +import { ClsMiddleware } from 'nestjs-cls'; async function bootstrap() { const app = await NestFactory.create(AppModule); @@ -11,6 +12,10 @@ async function bootstrap() { app.useGlobalPipes(new ValidationPipe()); app.use(cookieParser()); + // Doing this to make sure it's always the first middleware. Since it does hold auth data. + app.use(new ClsMiddleware().use); + + // Rendering app.useStaticAssets(join(__dirname, '..', 'public')); app.setBaseViewsDir(join(__dirname, '..', 'views')); app.setViewEngine('hbs');