feat: more implementing oidc-core

This commit is contained in:
Kakious 2024-07-18 22:51:36 -04:00
parent b38f5841e3
commit c65ba92322
17 changed files with 333 additions and 49 deletions

View file

@ -1,4 +1,7 @@
{
"singleQuote": true,
"trailingComma": "all"
{
"singleQuote": true,
"trailingComma": "all",
"tabWidth": 2,
"semi": true,
"printWidth": 100
}

View file

@ -32,6 +32,7 @@
"@nestjs/config": "^3.2.3",
"@nestjs/core": "^10.3.10",
"@nestjs/platform-express": "^10.3.10",
"@nestjs/swagger": "^7.4.0",
"@nestjs/terminus": "^10.2.3",
"@nestjs/typeorm": "^10.0.2",
"@opentelemetry/api": "^1.9.0",
@ -52,7 +53,8 @@
"rxjs": "^7.8.1",
"typeorm": "^0.3.20",
"typia": "^6.5.1",
"uuid": "^10.0.0"
"uuid": "^10.0.0",
"wildcard": "^2.0.1"
},
"devDependencies": {
"@nestjs/cli": "^10.4.2",

11
pnpm-lock.yaml generated
View file

@ -26,6 +26,9 @@ importers:
'@nestjs/platform-express':
specifier: ^10.3.10
version: 10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10)
'@nestjs/swagger':
specifier: ^7.4.0
version: 7.4.0(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)
'@nestjs/terminus':
specifier: ^10.2.3
version: 10.2.3(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/typeorm@10.0.2(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(reflect-metadata@0.2.2)(rxjs@7.8.1)(typeorm@0.3.20(ioredis@5.4.1)(mysql2@3.10.2)(ts-node@10.9.2(@types/node@20.14.10)(typescript@5.5.3))))(reflect-metadata@0.2.2)(rxjs@7.8.1)(typeorm@0.3.20(ioredis@5.4.1)(mysql2@3.10.2)(ts-node@10.9.2(@types/node@20.14.10)(typescript@5.5.3)))
@ -89,6 +92,9 @@ importers:
uuid:
specifier: ^10.0.0
version: 10.0.0
wildcard:
specifier: ^2.0.1
version: 2.0.1
devDependencies:
'@nestjs/cli':
specifier: ^10.4.2
@ -4001,6 +4007,9 @@ packages:
resolution: {integrity: sha512-NsmoXalsWVDMGupxZ5R08ka9flZjjiLvHVAWYOKtiKM8ujtZWr9cRffak+uSE48+Ob8ObalXpwyeUiyDD6QFgg==}
engines: {node: '>=8'}
wildcard@2.0.1:
resolution: {integrity: sha512-CC1bOL87PIWSBhDcTrdeLo6eGT7mCFtrg0uIJtqJUFyK+eJnzl8A1niH56uu7KMa5XFrtiV+AQuHO3n7DsHnLQ==}
with@7.0.2:
resolution: {integrity: sha512-RNGKj82nUPg3g5ygxkQl0R937xLyho1J24ItRCBTr/m1YnZkzJy1hUiHUJrc/VlsDQzsCnInEGSg3bci0Lmd4w==}
engines: {node: '>= 10.0.0'}
@ -8407,6 +8416,8 @@ snapshots:
dependencies:
string-width: 4.2.3
wildcard@2.0.1: {}
with@7.0.2:
dependencies:
'@babel/parser': 7.24.8

View file

@ -1,3 +1,11 @@
// This is the internal client ID for the application itself.
export const internalClientId = 'system.internal.management';
export const internalRedirectTag = 'system.internalRedirect.management';
// This is the master org configuration. This is the default org configuration that is used when no org is specified.
export const masterOrgId = 1;
export const masterOrgName = 'WaterWolf';
export const masterOrgSlug = 'waterwolf';
export const masterOrgAdminRole = 'admin';
export const masterOrgUserRole = 'user';

View file

@ -0,0 +1,19 @@
import { All, Controller, Req, Res, UseInterceptors } from '@nestjs/common';
import { ApiExcludeController } from '@nestjs/swagger';
import { OidcService } from '../oidc/core.service';
import { ExpressResErrorInterceptor } from '../../interceptor/express_res_error.interceptor';
@UseInterceptors(new ExpressResErrorInterceptor())
@ApiExcludeController()
@Controller('oidc')
export class OidcController {
constructor(private readonly oidcService: OidcService) {}
@All('/*')
// eslint-disable-next-line @typescript-eslint/explicit-function-return-type
public mountedOidc(@Req() req: any, @Res() res: any) {
req.url = req.originalUrl.replace('/oidc', '');
const callback = this.oidcService.provider.callback();
return callback(req, res);
}
}

View file

@ -3,8 +3,9 @@ import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
import { promises as fs } from 'fs';
import type Provider from 'oidc-provider';
import psl from 'psl';
import type { Configuration, errors, KoaContextWithOIDC } from 'oidc-provider';
import type { Configuration, errors } from 'oidc-provider';
import { createOidcAdapter } from './adapter';
import wildcard from 'wildcard';
import {
ACCESS_TOKEN_LIFE,
AUTHORIZATION_TOKEN_LIFE,
@ -17,6 +18,12 @@ import {
REFRESH_TOKEN_LIFE,
SESSION_LIFE,
} from './oidc.const';
import { ConfigService } from '@nestjs/config';
import { DataSource } from 'typeorm';
import { RedisService } from '../../redis/redis.service';
import { UserService } from 'src/user/user.service';
import { Span } from 'nestjs-otel';
import generateId from './helper/nanoid.helper';
// This is an async import for the oidc-provider package as it's now only esm and we need to use it in a commonjs environment.
async function getProvider(): Promise<{
@ -54,15 +61,24 @@ async function getProvider(): Promise<{
* OidcService is the service that handles all interaction with the OIDC library and package
*/
export class OidcService implements OnModuleInit {
private provider!: Provider;
private _provider!: Provider;
private jwks: any;
private cookies: any;
private readonly logger = new Logger(OidcService.name);
constructor() {}
constructor(
private readonly userService: UserService,
private readonly configService: ConfigService,
private readonly dataSource: DataSource,
private readonly redisService: RedisService,
) {}
async onModuleInit() {
await this.loadKeys();
const { provider, providerErrors } = await getProvider();
const baseUrl = this.configService.getOrThrow<string>('base.localUrl');
const isOrigin = (value: string): boolean => {
if (typeof value !== 'string') {
@ -182,17 +198,14 @@ export class OidcService implements OnModuleInit {
},
},
clientBasedCORS(ctx, origin, client) {
// ctx.oidc.route can be used to exclude endpoints from this behaviour, in that case just return
// true to always allow CORS on them, false to deny
// you may also allow some known internal origins if you want to
return (client['client_cors'] as string[]).includes(origin);
},
jwks: this.jwks,
cookies: {
keys: this.cookies,
},
// TODO: FUCKING IMPLEMENT THIS
findAccount: this.findAccount.bind(this),
// TODO: FUCKING IMPLEMENT THIS AHHHHHHHHHH
findAccount: this.userService.oidcFindAccount.bind(this),
ttl: {
AccessToken: ACCESS_TOKEN_LIFE,
RefreshToken: function RefreshTokenTTL(ctx, token, client) {
@ -223,8 +236,7 @@ export class OidcService implements OnModuleInit {
}
return (
code.scopes.has('offline_access') ||
(client.applicationType === 'web' &&
client.tokenEndpointAuthMethod === 'none')
(client.applicationType === 'web' && client.tokenEndpointAuthMethod === 'none')
);
},
features: {
@ -242,9 +254,9 @@ export class OidcService implements OnModuleInit {
}
if (token.clientId !== ctx.oidc.client?.clientId) {
return (
client['allowed_introspection_targets'] as string[]
).includes(token.clientId!);
return (client['allowed_introspection_targets'] as string[]).includes(
token.clientId!,
);
}
return true;
},
@ -296,7 +308,6 @@ export class OidcService implements OnModuleInit {
openid: ['sub'],
profile: ['name', 'preferred_username', 'picture', 'displayName'],
email: ['email', 'email_verified'],
roles: ['roles', 'permissions'],
},
pkce: {
methods: ['S256'],
@ -305,6 +316,68 @@ export class OidcService implements OnModuleInit {
},
},
} as Configuration;
const oidc = new provider(`${baseUrl}/oidc`, config) as Provider;
oidc.proxy = true;
const { redirectUriAllowed, postLogoutRedirectUriAllowed } = oidc.Client.prototype;
const hasWildcardHost = (redirectUri: string) => {
const { hostname } = new URL(redirectUri);
return hostname.includes('*');
};
const wildcardMatches = (redirectUri: string, wildcardUri: string) =>
!!wildcard(wildcardUri, redirectUri);
const logger = this.logger;
oidc.Client.prototype.redirectUriAllowed = function wildcardRedirectUriAllowed(redirectUri) {
logger.log(
{
redirectUri,
redirectUris: this.redirectUris,
},
'Checking redirect uri - 1',
);
if (!this.redirectUris) {
return redirectUriAllowed.call(this, redirectUri);
}
if (!this.redirectUris.some(hasWildcardHost)) {
return redirectUriAllowed.call(this, redirectUri);
}
const wildcardUris = this.redirectUris.filter(hasWildcardHost);
if (wildcardUris.some(wildcardMatches.bind(undefined, redirectUri))) {
return true;
}
return redirectUriAllowed.call(this, redirectUri);
};
oidc.Client.prototype.postLogoutRedirectUriAllowed =
function wildcardPostLogoutRedirectUriAllowed(redirectUri) {
if (!this.postLogoutRedirectUris) {
return postLogoutRedirectUriAllowed.call(this, redirectUri);
}
if (!this.postLogoutRedirectUris.some(hasWildcardHost)) {
return postLogoutRedirectUriAllowed.call(this, redirectUri);
}
const wildcardUris = this.postLogoutRedirectUris.filter(hasWildcardHost);
if (wildcardUris.some(wildcardMatches.bind(undefined, redirectUri))) {
return true;
}
return postLogoutRedirectUriAllowed.call(this, redirectUri);
};
this._provider = oidc;
}
public get provider(): Provider {
return this._provider;
}
/// === SUPPORT FUNCTIONS === ///
@ -317,18 +390,12 @@ export class OidcService implements OnModuleInit {
const keysFolder = this.getKeysFolder();
try {
this.jwks = JSON.parse(
await fs.readFile(keysFolder + 'keys/jwks.json', 'utf8'),
);
this.cookies = JSON.parse(
await fs.readFile(keysFolder + 'keys/cookies.json', 'utf8'),
);
this.jwks = JSON.parse(await fs.readFile(keysFolder + 'keys/jwks.json', 'utf8'));
this.cookies = JSON.parse(await fs.readFile(keysFolder + 'keys/cookies.json', 'utf8'));
this.logger.debug(`Loaded oidc keys successfully`);
} catch (error) {
throw new Error(
`Failed to load keys, keys should be located ${keysFolder}keys`,
);
throw new Error(`Failed to load keys, keys should be located ${keysFolder}keys`);
}
}
@ -338,4 +405,13 @@ export class OidcService implements OnModuleInit {
}
return __dirname + '/../../../../';
}
/**
* Generate a JTI for an OIDC object
* @returns string - The JTI
*/
@Span()
async generateJti(): Promise<string> {
return await generateId();
}
}

View file

@ -0,0 +1,43 @@
// Write an error class that extends the base error class
//
export enum OidcErrorType {
InvalidRequest = 'invalid_request',
InvalidClient = 'invalid_client',
InvalidGrant = 'invalid_grant',
UnauthorizedClient = 'unauthorized_client',
InvalidSession = 'invalid_session',
}
export default class OidcError extends Error {
constructor(
public errorType: OidcErrorType,
public errorDescription: string,
public errorUri?: string,
) {
super(errorDescription);
}
public toOidcError() {
return {
error: this.errorType,
error_description: this.errorDescription,
error_uri: this.errorUri,
};
}
public static fromOidcError(oidcError: any) {
return new OidcError(oidcError.error, oidcError.error_description, oidcError.error_uri);
}
public static invalidSession() {
return new OidcError(OidcErrorType.InvalidSession, 'Session token is invalid');
}
public static invalidRequest() {
return new OidcError(OidcErrorType.InvalidRequest, 'Request is invalid');
}
public static invalidClient() {
return new OidcError(OidcErrorType.InvalidClient, 'Client is invalid');
}
}

View file

@ -1,13 +1,6 @@
import type { ClientAuthMethod, ResponseType } from 'oidc-provider';
import {
Column,
Entity,
JoinTable,
ManyToMany,
OneToMany,
PrimaryGeneratedColumn,
} from 'typeorm';
import { Column, Entity, JoinTable, ManyToMany, OneToMany, PrimaryGeneratedColumn } from 'typeorm';
import { OidcClientPermission } from './oidc_client_permissions.model';
import { MAX_STRING_LENGTH } from '../database.const';
@ -20,7 +13,7 @@ export class OidcClient {
// Owner Org ID
@Column({ type: 'varchar', length: MAX_STRING_LENGTH, nullable: false })
ownerName: string;
owner_org_id: string;
@Column({ type: 'varchar', length: MAX_STRING_LENGTH, nullable: false })
client_secret: string;
@ -28,21 +21,24 @@ export class OidcClient {
@Column({ type: 'varchar', length: MAX_STRING_LENGTH, nullable: true })
client_name: string;
@Column({ type: 'simple-array', nullable: false })
@Column({ type: 'simple-array', nullable: false, default: [] })
redirect_uris: string[];
@Column({ type: 'simple-array', nullable: true })
@Column({ type: 'simple-array', nullable: true, default: [] })
client_cors: string[];
@Column({ type: 'simple-array', nullable: true })
@Column({ type: 'simple-array', nullable: true, default: [] })
allowed_introspection_targets: string[];
@Column({ type: 'simple-array', nullable: false })
@Column({ type: 'simple-array', nullable: false, default: [] })
include_permissions_from_client: string[];
@Column({ type: 'simple-array', nullable: false })
@Column({ type: 'simple-array', nullable: false, default: [] })
post_logout_redirect_uris: string[];
@Column({ type: 'simple-array', nullable: true })
default_scopes: string[];
@Column({ type: 'simple-array', nullable: false })
response_types: ResponseType[];
@ -64,10 +60,7 @@ export class OidcClient {
@OneToMany(() => OidcClientPermission, (permission) => permission.client)
permissions: OidcClientPermission[];
@ManyToMany(
() => OidcClientPermission,
(permission) => permission.assignedClients,
)
@ManyToMany(() => OidcClientPermission, (permission) => permission.assignedClients)
@JoinTable({
name: 'oidc_client_permissions',
joinColumn: {

View file

@ -9,6 +9,9 @@ export class Organization {
@Column({ length: MAX_STRING_LENGTH })
name: string;
@Column({ length: MAX_STRING_LENGTH, unique: true })
slug: string;
@Column({ length: MAX_STRING_LENGTH, nullable: true })
logo?: string;

View file

@ -0,0 +1 @@
// This contains the inital setup data for the IdP. This should be standalone.

View file

@ -0,0 +1,44 @@
import type { CallHandler, ExecutionContext, NestInterceptor } from '@nestjs/common';
import { HttpException, Injectable, Logger } from '@nestjs/common';
import type { Response } from 'express';
import { catchError, type Observable, throwError } from 'rxjs';
/**
* Finishes the request if an error was unexpectedly thrown.
* This ensures the requests don't hang.
* See https://github.com/nestjs/nest/issues/5448
*/
@Injectable()
export class ExpressResErrorInterceptor implements NestInterceptor {
private readonly logger = new Logger(ExpressResErrorInterceptor.name);
intercept(ctx: ExecutionContext, next: CallHandler): Observable<unknown> {
const res = ctx.switchToHttp().getResponse<Response>();
return next.handle().pipe(
catchError((err: unknown) => {
if (!res.writableEnded) {
// Send an error response if it hasn't been done so the request doesn't hang.
this.logger.error(err, 'Uncaught internal error');
if (err instanceof HttpException) {
const statusCode = err.getStatus();
res.status(statusCode).send({
message: err.getResponse(),
statusCode,
});
} else {
res.status(500).send({
message: 'Internal Server Error',
error: 'Internal Server Error',
statusCode: 500,
});
}
}
return throwError(() => err);
}),
);
}
}

View file

@ -83,7 +83,7 @@ export class RedisService implements OnApplicationShutdown {
* @param key Key for the value to get
* @returns
*/
public async get(key: string): Promise<string | object | null> {
public async get<T = string>(key: string): Promise<T | null> {
this.logger.debug(`Getting key ${key}`);
const value = await this._ioredis.get(key);
if (!value) {
@ -91,9 +91,9 @@ export class RedisService implements OnApplicationShutdown {
}
try {
return JSON.parse(value);
return JSON.parse(value) as T;
} catch (error) {
return value;
return value as T;
}
}
@ -283,4 +283,4 @@ export class RedisService implements OnApplicationShutdown {
},
};
}
}
}

0
src/s3/s3.module.ts Normal file
View file

0
src/s3/s3.service.ts Normal file
View file

View file

@ -0,0 +1,6 @@
export const USER_NOT_FOUND_ERROR = 'User not found';
// Caching Constants for Redis
export const userCacheTTL = 60 * 60 * 24; // 24 hours
export const userCacheKey = 'ww-auth:user';

13
src/user/user.module.ts Normal file
View file

@ -0,0 +1,13 @@
import { Module } from '@nestjs/common';
import { TypeOrmModule } from '@nestjs/typeorm';
import { User } from '../database/models/user.model';
import { RedisModule } from '../redis/redis.module';
@Module({
imports: [TypeOrmModule.forFeature([User]), RedisModule],
controllers: [],
providers: [],
exports: [],
})
export class UserModule {}

62
src/user/user.service.ts Normal file
View file

@ -0,0 +1,62 @@
import { Injectable, NotFoundException } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { User } from '../database/models/user.model';
import { RedisService } from '../redis/redis.service';
import { Span } from 'nestjs-otel';
import { USER_NOT_FOUND_ERROR, userCacheKey } from './user.constant';
@Injectable()
export class UserService {
constructor(
@InjectRepository(User)
private readonly userRepository: Repository<User>,
private readonly redisService: RedisService,
) {}
/**
* Get a user by their ID
* @param id The user's ID
* @oaram relations The DB relations to include
* @returns Promise<User> The user
* @throws NotFoundException
*/
@Span()
async getUserById(id: string, relations: string[] = []): Promise<User> {
const cacheKey = userCacheKey + id;
const cachedUser = await this.redisService.get<User>(cacheKey);
if (cachedUser && relations.length === 0) {
return cachedUser;
}
const relationsToInclude: Record<string, boolean> = relations.reduce(
(acc: Record<string, boolean>, relation: string) => {
acc[relation] = true;
return acc;
},
{},
);
const queryBuilder = this.userRepository.createQueryBuilder('user');
queryBuilder.where('user.id = :id', { id });
// TODO: Update once org roles are introduced
for (const relation in relationsToInclude) {
queryBuilder.leftJoinAndSelect(`user.${relation}`, relation);
}
const user = await queryBuilder.getOne();
if (user && relations.length === 0) {
await this.redisService.set(cacheKey, user);
}
if (!user) {
throw new NotFoundException(USER_NOT_FOUND_ERROR);
}
return user;
}
}