feat: more implementing oidc-core

This commit is contained in:
Kakious 2024-07-18 22:51:36 -04:00
parent b38f5841e3
commit c65ba92322
17 changed files with 333 additions and 49 deletions

View file

@ -1,4 +1,7 @@
{ {
"singleQuote": true, "singleQuote": true,
"trailingComma": "all" "trailingComma": "all",
"tabWidth": 2,
"semi": true,
"printWidth": 100
} }

View file

@ -32,6 +32,7 @@
"@nestjs/config": "^3.2.3", "@nestjs/config": "^3.2.3",
"@nestjs/core": "^10.3.10", "@nestjs/core": "^10.3.10",
"@nestjs/platform-express": "^10.3.10", "@nestjs/platform-express": "^10.3.10",
"@nestjs/swagger": "^7.4.0",
"@nestjs/terminus": "^10.2.3", "@nestjs/terminus": "^10.2.3",
"@nestjs/typeorm": "^10.0.2", "@nestjs/typeorm": "^10.0.2",
"@opentelemetry/api": "^1.9.0", "@opentelemetry/api": "^1.9.0",
@ -52,7 +53,8 @@
"rxjs": "^7.8.1", "rxjs": "^7.8.1",
"typeorm": "^0.3.20", "typeorm": "^0.3.20",
"typia": "^6.5.1", "typia": "^6.5.1",
"uuid": "^10.0.0" "uuid": "^10.0.0",
"wildcard": "^2.0.1"
}, },
"devDependencies": { "devDependencies": {
"@nestjs/cli": "^10.4.2", "@nestjs/cli": "^10.4.2",

View file

@ -26,6 +26,9 @@ importers:
'@nestjs/platform-express': '@nestjs/platform-express':
specifier: ^10.3.10 specifier: ^10.3.10
version: 10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10) version: 10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10)
'@nestjs/swagger':
specifier: ^7.4.0
version: 7.4.0(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)
'@nestjs/terminus': '@nestjs/terminus':
specifier: ^10.2.3 specifier: ^10.2.3
version: 10.2.3(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/typeorm@10.0.2(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(reflect-metadata@0.2.2)(rxjs@7.8.1)(typeorm@0.3.20(ioredis@5.4.1)(mysql2@3.10.2)(ts-node@10.9.2(@types/node@20.14.10)(typescript@5.5.3))))(reflect-metadata@0.2.2)(rxjs@7.8.1)(typeorm@0.3.20(ioredis@5.4.1)(mysql2@3.10.2)(ts-node@10.9.2(@types/node@20.14.10)(typescript@5.5.3))) version: 10.2.3(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/typeorm@10.0.2(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/core@10.3.10(@nestjs/common@10.3.10(class-transformer@0.5.1)(class-validator@0.14.1)(reflect-metadata@0.2.2)(rxjs@7.8.1))(@nestjs/platform-express@10.3.10)(reflect-metadata@0.2.2)(rxjs@7.8.1))(reflect-metadata@0.2.2)(rxjs@7.8.1)(typeorm@0.3.20(ioredis@5.4.1)(mysql2@3.10.2)(ts-node@10.9.2(@types/node@20.14.10)(typescript@5.5.3))))(reflect-metadata@0.2.2)(rxjs@7.8.1)(typeorm@0.3.20(ioredis@5.4.1)(mysql2@3.10.2)(ts-node@10.9.2(@types/node@20.14.10)(typescript@5.5.3)))
@ -89,6 +92,9 @@ importers:
uuid: uuid:
specifier: ^10.0.0 specifier: ^10.0.0
version: 10.0.0 version: 10.0.0
wildcard:
specifier: ^2.0.1
version: 2.0.1
devDependencies: devDependencies:
'@nestjs/cli': '@nestjs/cli':
specifier: ^10.4.2 specifier: ^10.4.2
@ -4001,6 +4007,9 @@ packages:
resolution: {integrity: sha512-NsmoXalsWVDMGupxZ5R08ka9flZjjiLvHVAWYOKtiKM8ujtZWr9cRffak+uSE48+Ob8ObalXpwyeUiyDD6QFgg==} resolution: {integrity: sha512-NsmoXalsWVDMGupxZ5R08ka9flZjjiLvHVAWYOKtiKM8ujtZWr9cRffak+uSE48+Ob8ObalXpwyeUiyDD6QFgg==}
engines: {node: '>=8'} engines: {node: '>=8'}
wildcard@2.0.1:
resolution: {integrity: sha512-CC1bOL87PIWSBhDcTrdeLo6eGT7mCFtrg0uIJtqJUFyK+eJnzl8A1niH56uu7KMa5XFrtiV+AQuHO3n7DsHnLQ==}
with@7.0.2: with@7.0.2:
resolution: {integrity: sha512-RNGKj82nUPg3g5ygxkQl0R937xLyho1J24ItRCBTr/m1YnZkzJy1hUiHUJrc/VlsDQzsCnInEGSg3bci0Lmd4w==} resolution: {integrity: sha512-RNGKj82nUPg3g5ygxkQl0R937xLyho1J24ItRCBTr/m1YnZkzJy1hUiHUJrc/VlsDQzsCnInEGSg3bci0Lmd4w==}
engines: {node: '>= 10.0.0'} engines: {node: '>= 10.0.0'}
@ -8407,6 +8416,8 @@ snapshots:
dependencies: dependencies:
string-width: 4.2.3 string-width: 4.2.3
wildcard@2.0.1: {}
with@7.0.2: with@7.0.2:
dependencies: dependencies:
'@babel/parser': 7.24.8 '@babel/parser': 7.24.8

View file

@ -1,3 +1,11 @@
// This is the internal client ID for the application itself. // This is the internal client ID for the application itself.
export const internalClientId = 'system.internal.management'; export const internalClientId = 'system.internal.management';
export const internalRedirectTag = 'system.internalRedirect.management'; export const internalRedirectTag = 'system.internalRedirect.management';
// This is the master org configuration. This is the default org configuration that is used when no org is specified.
export const masterOrgId = 1;
export const masterOrgName = 'WaterWolf';
export const masterOrgSlug = 'waterwolf';
export const masterOrgAdminRole = 'admin';
export const masterOrgUserRole = 'user';

View file

@ -0,0 +1,19 @@
import { All, Controller, Req, Res, UseInterceptors } from '@nestjs/common';
import { ApiExcludeController } from '@nestjs/swagger';
import { OidcService } from '../oidc/core.service';
import { ExpressResErrorInterceptor } from '../../interceptor/express_res_error.interceptor';
@UseInterceptors(new ExpressResErrorInterceptor())
@ApiExcludeController()
@Controller('oidc')
export class OidcController {
constructor(private readonly oidcService: OidcService) {}
@All('/*')
// eslint-disable-next-line @typescript-eslint/explicit-function-return-type
public mountedOidc(@Req() req: any, @Res() res: any) {
req.url = req.originalUrl.replace('/oidc', '');
const callback = this.oidcService.provider.callback();
return callback(req, res);
}
}

View file

@ -3,8 +3,9 @@ import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
import { promises as fs } from 'fs'; import { promises as fs } from 'fs';
import type Provider from 'oidc-provider'; import type Provider from 'oidc-provider';
import psl from 'psl'; import psl from 'psl';
import type { Configuration, errors, KoaContextWithOIDC } from 'oidc-provider'; import type { Configuration, errors } from 'oidc-provider';
import { createOidcAdapter } from './adapter'; import { createOidcAdapter } from './adapter';
import wildcard from 'wildcard';
import { import {
ACCESS_TOKEN_LIFE, ACCESS_TOKEN_LIFE,
AUTHORIZATION_TOKEN_LIFE, AUTHORIZATION_TOKEN_LIFE,
@ -17,6 +18,12 @@ import {
REFRESH_TOKEN_LIFE, REFRESH_TOKEN_LIFE,
SESSION_LIFE, SESSION_LIFE,
} from './oidc.const'; } from './oidc.const';
import { ConfigService } from '@nestjs/config';
import { DataSource } from 'typeorm';
import { RedisService } from '../../redis/redis.service';
import { UserService } from 'src/user/user.service';
import { Span } from 'nestjs-otel';
import generateId from './helper/nanoid.helper';
// This is an async import for the oidc-provider package as it's now only esm and we need to use it in a commonjs environment. // This is an async import for the oidc-provider package as it's now only esm and we need to use it in a commonjs environment.
async function getProvider(): Promise<{ async function getProvider(): Promise<{
@ -54,15 +61,24 @@ async function getProvider(): Promise<{
* OidcService is the service that handles all interaction with the OIDC library and package * OidcService is the service that handles all interaction with the OIDC library and package
*/ */
export class OidcService implements OnModuleInit { export class OidcService implements OnModuleInit {
private provider!: Provider; private _provider!: Provider;
private jwks: any; private jwks: any;
private cookies: any; private cookies: any;
private readonly logger = new Logger(OidcService.name); private readonly logger = new Logger(OidcService.name);
constructor() {} constructor(
private readonly userService: UserService,
private readonly configService: ConfigService,
private readonly dataSource: DataSource,
private readonly redisService: RedisService,
) {}
async onModuleInit() { async onModuleInit() {
await this.loadKeys();
const { provider, providerErrors } = await getProvider(); const { provider, providerErrors } = await getProvider();
const baseUrl = this.configService.getOrThrow<string>('base.localUrl');
const isOrigin = (value: string): boolean => { const isOrigin = (value: string): boolean => {
if (typeof value !== 'string') { if (typeof value !== 'string') {
@ -182,17 +198,14 @@ export class OidcService implements OnModuleInit {
}, },
}, },
clientBasedCORS(ctx, origin, client) { clientBasedCORS(ctx, origin, client) {
// ctx.oidc.route can be used to exclude endpoints from this behaviour, in that case just return
// true to always allow CORS on them, false to deny
// you may also allow some known internal origins if you want to
return (client['client_cors'] as string[]).includes(origin); return (client['client_cors'] as string[]).includes(origin);
}, },
jwks: this.jwks, jwks: this.jwks,
cookies: { cookies: {
keys: this.cookies, keys: this.cookies,
}, },
// TODO: FUCKING IMPLEMENT THIS // TODO: FUCKING IMPLEMENT THIS AHHHHHHHHHH
findAccount: this.findAccount.bind(this), findAccount: this.userService.oidcFindAccount.bind(this),
ttl: { ttl: {
AccessToken: ACCESS_TOKEN_LIFE, AccessToken: ACCESS_TOKEN_LIFE,
RefreshToken: function RefreshTokenTTL(ctx, token, client) { RefreshToken: function RefreshTokenTTL(ctx, token, client) {
@ -223,8 +236,7 @@ export class OidcService implements OnModuleInit {
} }
return ( return (
code.scopes.has('offline_access') || code.scopes.has('offline_access') ||
(client.applicationType === 'web' && (client.applicationType === 'web' && client.tokenEndpointAuthMethod === 'none')
client.tokenEndpointAuthMethod === 'none')
); );
}, },
features: { features: {
@ -242,9 +254,9 @@ export class OidcService implements OnModuleInit {
} }
if (token.clientId !== ctx.oidc.client?.clientId) { if (token.clientId !== ctx.oidc.client?.clientId) {
return ( return (client['allowed_introspection_targets'] as string[]).includes(
client['allowed_introspection_targets'] as string[] token.clientId!,
).includes(token.clientId!); );
} }
return true; return true;
}, },
@ -296,7 +308,6 @@ export class OidcService implements OnModuleInit {
openid: ['sub'], openid: ['sub'],
profile: ['name', 'preferred_username', 'picture', 'displayName'], profile: ['name', 'preferred_username', 'picture', 'displayName'],
email: ['email', 'email_verified'], email: ['email', 'email_verified'],
roles: ['roles', 'permissions'],
}, },
pkce: { pkce: {
methods: ['S256'], methods: ['S256'],
@ -305,6 +316,68 @@ export class OidcService implements OnModuleInit {
}, },
}, },
} as Configuration; } as Configuration;
const oidc = new provider(`${baseUrl}/oidc`, config) as Provider;
oidc.proxy = true;
const { redirectUriAllowed, postLogoutRedirectUriAllowed } = oidc.Client.prototype;
const hasWildcardHost = (redirectUri: string) => {
const { hostname } = new URL(redirectUri);
return hostname.includes('*');
};
const wildcardMatches = (redirectUri: string, wildcardUri: string) =>
!!wildcard(wildcardUri, redirectUri);
const logger = this.logger;
oidc.Client.prototype.redirectUriAllowed = function wildcardRedirectUriAllowed(redirectUri) {
logger.log(
{
redirectUri,
redirectUris: this.redirectUris,
},
'Checking redirect uri - 1',
);
if (!this.redirectUris) {
return redirectUriAllowed.call(this, redirectUri);
}
if (!this.redirectUris.some(hasWildcardHost)) {
return redirectUriAllowed.call(this, redirectUri);
}
const wildcardUris = this.redirectUris.filter(hasWildcardHost);
if (wildcardUris.some(wildcardMatches.bind(undefined, redirectUri))) {
return true;
}
return redirectUriAllowed.call(this, redirectUri);
};
oidc.Client.prototype.postLogoutRedirectUriAllowed =
function wildcardPostLogoutRedirectUriAllowed(redirectUri) {
if (!this.postLogoutRedirectUris) {
return postLogoutRedirectUriAllowed.call(this, redirectUri);
}
if (!this.postLogoutRedirectUris.some(hasWildcardHost)) {
return postLogoutRedirectUriAllowed.call(this, redirectUri);
}
const wildcardUris = this.postLogoutRedirectUris.filter(hasWildcardHost);
if (wildcardUris.some(wildcardMatches.bind(undefined, redirectUri))) {
return true;
}
return postLogoutRedirectUriAllowed.call(this, redirectUri);
};
this._provider = oidc;
}
public get provider(): Provider {
return this._provider;
} }
/// === SUPPORT FUNCTIONS === /// /// === SUPPORT FUNCTIONS === ///
@ -317,18 +390,12 @@ export class OidcService implements OnModuleInit {
const keysFolder = this.getKeysFolder(); const keysFolder = this.getKeysFolder();
try { try {
this.jwks = JSON.parse( this.jwks = JSON.parse(await fs.readFile(keysFolder + 'keys/jwks.json', 'utf8'));
await fs.readFile(keysFolder + 'keys/jwks.json', 'utf8'), this.cookies = JSON.parse(await fs.readFile(keysFolder + 'keys/cookies.json', 'utf8'));
);
this.cookies = JSON.parse(
await fs.readFile(keysFolder + 'keys/cookies.json', 'utf8'),
);
this.logger.debug(`Loaded oidc keys successfully`); this.logger.debug(`Loaded oidc keys successfully`);
} catch (error) { } catch (error) {
throw new Error( throw new Error(`Failed to load keys, keys should be located ${keysFolder}keys`);
`Failed to load keys, keys should be located ${keysFolder}keys`,
);
} }
} }
@ -338,4 +405,13 @@ export class OidcService implements OnModuleInit {
} }
return __dirname + '/../../../../'; return __dirname + '/../../../../';
} }
/**
* Generate a JTI for an OIDC object
* @returns string - The JTI
*/
@Span()
async generateJti(): Promise<string> {
return await generateId();
}
} }

View file

@ -0,0 +1,43 @@
// Write an error class that extends the base error class
//
export enum OidcErrorType {
InvalidRequest = 'invalid_request',
InvalidClient = 'invalid_client',
InvalidGrant = 'invalid_grant',
UnauthorizedClient = 'unauthorized_client',
InvalidSession = 'invalid_session',
}
export default class OidcError extends Error {
constructor(
public errorType: OidcErrorType,
public errorDescription: string,
public errorUri?: string,
) {
super(errorDescription);
}
public toOidcError() {
return {
error: this.errorType,
error_description: this.errorDescription,
error_uri: this.errorUri,
};
}
public static fromOidcError(oidcError: any) {
return new OidcError(oidcError.error, oidcError.error_description, oidcError.error_uri);
}
public static invalidSession() {
return new OidcError(OidcErrorType.InvalidSession, 'Session token is invalid');
}
public static invalidRequest() {
return new OidcError(OidcErrorType.InvalidRequest, 'Request is invalid');
}
public static invalidClient() {
return new OidcError(OidcErrorType.InvalidClient, 'Client is invalid');
}
}

View file

@ -1,13 +1,6 @@
import type { ClientAuthMethod, ResponseType } from 'oidc-provider'; import type { ClientAuthMethod, ResponseType } from 'oidc-provider';
import { import { Column, Entity, JoinTable, ManyToMany, OneToMany, PrimaryGeneratedColumn } from 'typeorm';
Column,
Entity,
JoinTable,
ManyToMany,
OneToMany,
PrimaryGeneratedColumn,
} from 'typeorm';
import { OidcClientPermission } from './oidc_client_permissions.model'; import { OidcClientPermission } from './oidc_client_permissions.model';
import { MAX_STRING_LENGTH } from '../database.const'; import { MAX_STRING_LENGTH } from '../database.const';
@ -20,7 +13,7 @@ export class OidcClient {
// Owner Org ID // Owner Org ID
@Column({ type: 'varchar', length: MAX_STRING_LENGTH, nullable: false }) @Column({ type: 'varchar', length: MAX_STRING_LENGTH, nullable: false })
ownerName: string; owner_org_id: string;
@Column({ type: 'varchar', length: MAX_STRING_LENGTH, nullable: false }) @Column({ type: 'varchar', length: MAX_STRING_LENGTH, nullable: false })
client_secret: string; client_secret: string;
@ -28,21 +21,24 @@ export class OidcClient {
@Column({ type: 'varchar', length: MAX_STRING_LENGTH, nullable: true }) @Column({ type: 'varchar', length: MAX_STRING_LENGTH, nullable: true })
client_name: string; client_name: string;
@Column({ type: 'simple-array', nullable: false }) @Column({ type: 'simple-array', nullable: false, default: [] })
redirect_uris: string[]; redirect_uris: string[];
@Column({ type: 'simple-array', nullable: true }) @Column({ type: 'simple-array', nullable: true, default: [] })
client_cors: string[]; client_cors: string[];
@Column({ type: 'simple-array', nullable: true }) @Column({ type: 'simple-array', nullable: true, default: [] })
allowed_introspection_targets: string[]; allowed_introspection_targets: string[];
@Column({ type: 'simple-array', nullable: false }) @Column({ type: 'simple-array', nullable: false, default: [] })
include_permissions_from_client: string[]; include_permissions_from_client: string[];
@Column({ type: 'simple-array', nullable: false }) @Column({ type: 'simple-array', nullable: false, default: [] })
post_logout_redirect_uris: string[]; post_logout_redirect_uris: string[];
@Column({ type: 'simple-array', nullable: true })
default_scopes: string[];
@Column({ type: 'simple-array', nullable: false }) @Column({ type: 'simple-array', nullable: false })
response_types: ResponseType[]; response_types: ResponseType[];
@ -64,10 +60,7 @@ export class OidcClient {
@OneToMany(() => OidcClientPermission, (permission) => permission.client) @OneToMany(() => OidcClientPermission, (permission) => permission.client)
permissions: OidcClientPermission[]; permissions: OidcClientPermission[];
@ManyToMany( @ManyToMany(() => OidcClientPermission, (permission) => permission.assignedClients)
() => OidcClientPermission,
(permission) => permission.assignedClients,
)
@JoinTable({ @JoinTable({
name: 'oidc_client_permissions', name: 'oidc_client_permissions',
joinColumn: { joinColumn: {

View file

@ -9,6 +9,9 @@ export class Organization {
@Column({ length: MAX_STRING_LENGTH }) @Column({ length: MAX_STRING_LENGTH })
name: string; name: string;
@Column({ length: MAX_STRING_LENGTH, unique: true })
slug: string;
@Column({ length: MAX_STRING_LENGTH, nullable: true }) @Column({ length: MAX_STRING_LENGTH, nullable: true })
logo?: string; logo?: string;

View file

@ -0,0 +1 @@
// This contains the inital setup data for the IdP. This should be standalone.

View file

@ -0,0 +1,44 @@
import type { CallHandler, ExecutionContext, NestInterceptor } from '@nestjs/common';
import { HttpException, Injectable, Logger } from '@nestjs/common';
import type { Response } from 'express';
import { catchError, type Observable, throwError } from 'rxjs';
/**
* Finishes the request if an error was unexpectedly thrown.
* This ensures the requests don't hang.
* See https://github.com/nestjs/nest/issues/5448
*/
@Injectable()
export class ExpressResErrorInterceptor implements NestInterceptor {
private readonly logger = new Logger(ExpressResErrorInterceptor.name);
intercept(ctx: ExecutionContext, next: CallHandler): Observable<unknown> {
const res = ctx.switchToHttp().getResponse<Response>();
return next.handle().pipe(
catchError((err: unknown) => {
if (!res.writableEnded) {
// Send an error response if it hasn't been done so the request doesn't hang.
this.logger.error(err, 'Uncaught internal error');
if (err instanceof HttpException) {
const statusCode = err.getStatus();
res.status(statusCode).send({
message: err.getResponse(),
statusCode,
});
} else {
res.status(500).send({
message: 'Internal Server Error',
error: 'Internal Server Error',
statusCode: 500,
});
}
}
return throwError(() => err);
}),
);
}
}

View file

@ -83,7 +83,7 @@ export class RedisService implements OnApplicationShutdown {
* @param key Key for the value to get * @param key Key for the value to get
* @returns * @returns
*/ */
public async get(key: string): Promise<string | object | null> { public async get<T = string>(key: string): Promise<T | null> {
this.logger.debug(`Getting key ${key}`); this.logger.debug(`Getting key ${key}`);
const value = await this._ioredis.get(key); const value = await this._ioredis.get(key);
if (!value) { if (!value) {
@ -91,9 +91,9 @@ export class RedisService implements OnApplicationShutdown {
} }
try { try {
return JSON.parse(value); return JSON.parse(value) as T;
} catch (error) { } catch (error) {
return value; return value as T;
} }
} }

0
src/s3/s3.module.ts Normal file
View file

0
src/s3/s3.service.ts Normal file
View file

View file

@ -0,0 +1,6 @@
export const USER_NOT_FOUND_ERROR = 'User not found';
// Caching Constants for Redis
export const userCacheTTL = 60 * 60 * 24; // 24 hours
export const userCacheKey = 'ww-auth:user';

13
src/user/user.module.ts Normal file
View file

@ -0,0 +1,13 @@
import { Module } from '@nestjs/common';
import { TypeOrmModule } from '@nestjs/typeorm';
import { User } from '../database/models/user.model';
import { RedisModule } from '../redis/redis.module';
@Module({
imports: [TypeOrmModule.forFeature([User]), RedisModule],
controllers: [],
providers: [],
exports: [],
})
export class UserModule {}

62
src/user/user.service.ts Normal file
View file

@ -0,0 +1,62 @@
import { Injectable, NotFoundException } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { User } from '../database/models/user.model';
import { RedisService } from '../redis/redis.service';
import { Span } from 'nestjs-otel';
import { USER_NOT_FOUND_ERROR, userCacheKey } from './user.constant';
@Injectable()
export class UserService {
constructor(
@InjectRepository(User)
private readonly userRepository: Repository<User>,
private readonly redisService: RedisService,
) {}
/**
* Get a user by their ID
* @param id The user's ID
* @oaram relations The DB relations to include
* @returns Promise<User> The user
* @throws NotFoundException
*/
@Span()
async getUserById(id: string, relations: string[] = []): Promise<User> {
const cacheKey = userCacheKey + id;
const cachedUser = await this.redisService.get<User>(cacheKey);
if (cachedUser && relations.length === 0) {
return cachedUser;
}
const relationsToInclude: Record<string, boolean> = relations.reduce(
(acc: Record<string, boolean>, relation: string) => {
acc[relation] = true;
return acc;
},
{},
);
const queryBuilder = this.userRepository.createQueryBuilder('user');
queryBuilder.where('user.id = :id', { id });
// TODO: Update once org roles are introduced
for (const relation in relationsToInclude) {
queryBuilder.leftJoinAndSelect(`user.${relation}`, relation);
}
const user = await queryBuilder.getOne();
if (user && relations.length === 0) {
await this.redisService.set(cacheKey, user);
}
if (!user) {
throw new NotFoundException(USER_NOT_FOUND_ERROR);
}
return user;
}
}