How to Implement Role-Based Access Control in NestJS with MongoDB - Part 5
Learn how to implement refresh tokens and invalidate tokens in NestJS with MongoDB in this part of the series
Introduction
This post is the fifth in a series titled "Implement Role-Based Access Control in NestJS using MongoDB." In this series, we'll create an RBAC (Role-Based Access Control) app from scratch using NestJS and MongoDB. Whether you're a beginner or looking to implement role-based access control in your app, follow along and build the app with me!
Previous Post
In the last post, we worked on the following things.
Adding Public Routes
Creating a custom Active User Decorator
Check out the blog post here
What's Next?
In this post, we will cover the following:
Implementing Refresh Tokens
Invalidating Tokens
Implementing Refresh Tokens
A refresh token is a special token used to obtain new access tokens. It lets us keep access tokens short-lived without forcing users to sign in again every time their access token expires. Once the access token expires, the client application sends the refresh token to the API to obtain a newly generated access token.
Refresh tokens should be returned alongside the access token as part of the user's initial authentication flow.
Applications must store refresh tokens securely.
First, we will add a new environment variable, JWT_REFRESH_TOKEN_TTL, to our .env file.
# .env file
# Refresh token lifetime in seconds (86400 = 24 hours)
JWT_REFRESH_TOKEN_TTL=86400
Next, register the variable in the JWT configuration as follows:
// config/jwt.config.ts
import { registerAs } from '@nestjs/config';
export default registerAs('jwt', () => {
  return {
    secret: process.env.JWT_SECRET,
    audience: process.env.JWT_TOKEN_AUDIENCE,
    issuer: process.env.JWT_TOKEN_ISSUER,
    accessTokenTtl: parseInt(process.env.JWT_ACCESS_TOKEN_TTL ?? '3600', 10),
    // New: refresh token lifetime in seconds (defaults to 24 hours)
    refreshTokenTtl: parseInt(process.env.JWT_REFRESH_TOKEN_TTL ?? '86400', 10),
  };
});
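One caveat with the config above: the `??` fallback only applies when the variable is undefined. If JWT_REFRESH_TOKEN_TTL is present but empty (for example `JWT_REFRESH_TOKEN_TTL=`), parseInt('') returns NaN instead of the default. A minimal sketch of a hypothetical parseTtl() helper (not part of @nestjs/config) that also falls back in that case:

```typescript
// Hypothetical helper (not part of @nestjs/config): parse a TTL in seconds,
// falling back to a default when the value is missing, empty or not a positive integer.
function parseTtl(value: string | undefined, fallback: number): number {
  const parsed = Number.parseInt(value ?? '', 10);
  return Number.isInteger(parsed) && parsed > 0 ? parsed : fallback;
}

// Possible usage inside registerAs('jwt', ...):
//   accessTokenTtl: parseTtl(process.env.JWT_ACCESS_TOKEN_TTL, 3600),
//   refreshTokenTtl: parseTtl(process.env.JWT_REFRESH_TOKEN_TTL, 86400),

console.log(parseTtl(undefined, 86400)); // 86400 (variable not set)
console.log(parseTtl('', 86400)); // 86400 (set but empty)
console.log(parseTtl('604800', 86400)); // 604800 (7 days)
```

Rejecting zero and negative values is a design choice in this sketch; adjust the check if you want to allow them.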
Next, let's improve the authentication service. The access_token and refresh_token have different expiration times, so instead of duplicating the signing logic, we refactor it into a separate generic method that can sign both.
// auth.service.ts
private async signToken<T>(userId: string, expiresIn: number, payload?: T) {
  // MongoDB user IDs are strings (user.id is the string form of the ObjectId)
  return await this.jwtService.signAsync(
    {
      sub: userId,
      ...payload,
    },
    {
      audience: this.jwtConfiguration.audience,
      issuer: this.jwtConfiguration.issuer,
      secret: this.jwtConfiguration.secret,
      expiresIn,
    },
  );
}
In the above code, the signToken() method takes three parameters:
userId: the unique user ID, which is stored in the token's sub claim.
expiresIn: the token's lifetime in seconds, which differs between the access token and the refresh token.
payload: an optional set of additional properties to sign and embed in the token.
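To see where these three parameters end up, here is a simplified sketch of an HS256 JWT, the default algorithm of jsonwebtoken (which @nestjs/jwt wraps) when signing with a shared secret. It is an illustration under stated assumptions, not how @nestjs/jwt is implemented internally: signHs256() and toBase64Url() are hypothetical helpers built on Node's crypto module, and the sample user ID and secret are placeholders. userId becomes the sub claim, payload is merged into the claims, and expiresIn sets exp relative to iat.

```typescript
import { createHmac } from 'node:crypto';

// Illustration only: roughly what jwtService.signAsync() builds for an HS256 token.
// signHs256() and toBase64Url() are hypothetical helpers, not @nestjs/jwt APIs.
interface SignOptions {
  secret: string;
  audience: string;
  issuer: string;
  expiresIn: number; // seconds, e.g. accessTokenTtl or refreshTokenTtl
}

function toBase64Url(input: string | Buffer): string {
  const buffer = typeof input === 'string' ? Buffer.from(input, 'utf8') : input;
  return buffer.toString('base64url');
}

function signHs256<T extends object>(
  userId: string,
  options: SignOptions,
  payload?: T,
  nowSeconds: number = Math.floor(Date.now() / 1000),
): string {
  const header = { alg: 'HS256', typ: 'JWT' };
  // `sub` and the custom payload, plus registered claims derived from the options
  const claims = {
    sub: userId,
    ...(payload ?? {}),
    iat: nowSeconds,
    exp: nowSeconds + options.expiresIn,
    aud: options.audience,
    iss: options.issuer,
  };
  const signingInput = `${toBase64Url(JSON.stringify(header))}.${toBase64Url(JSON.stringify(claims))}`;
  // HMAC-SHA256 over "header.payload" with the shared secret
  const signature = createHmac('sha256', options.secret).update(signingInput).digest();
  return `${signingInput}.${toBase64Url(signature)}`;
}

// Usage: an access token for a sample user, embedding the email claim
const token = signHs256(
  '665f1c2e9b1e8a0012345678',
  { secret: 'dev-secret', audience: 'localhost:3000', issuer: 'localhost:3000', expiresIn: 3600 },
  { email: 'jane@example.com' },
  1700000000,
);
const claims = JSON.parse(Buffer.from(token.split('.')[1], 'base64url').toString('utf8'));
console.log(claims.exp - claims.iat); // 3600
```

Because the payload is only base64url-encoded, not encrypted, anyone holding the token can read its claims; the signature only prevents tampering. Avoid putting sensitive data in the payload.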
Next, update the signIn() method to use signToken() instead of the previous JWT signing call.
// auth.service.ts
async signIn(signInDto: SignInDto) {
  ...
  ...
  const accessToken = await this.signToken<Partial<ActiveUserData>>(
    user.id,
    this.jwtConfiguration.accessTokenTtl,
    { email: user.email },
  );
  return { accessToken };
}
Here we use the Partial<T> utility type, which makes all properties of the given type optional. Because signToken() is generic, the same method can generate both the access token and the refresh token.
For better readability, and because we will reuse it later, we move token generation into a separate public method, generateTokens(). It uses Promise.all() to sign both tokens concurrently and returns the accessToken and refreshToken together.
// auth.service.ts
public async generateTokens(user: User) {
const [accessToken, refreshToken] = await Promise.all([
this.signToken<Partial<ActiveUserData>>(
user.id,
this.jwtConfiguration.accessTokenTtl,
{ email: user.email },
),
this.signToken(user.id, this.jwtConfiguration.refreshTokenTtl),
]);
return { accessToken, refreshToken };
}
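generateTokens() relies on Promise.all() starting both signing calls before awaiting either. The sketch below uses a hypothetical fakeSign() with artificial delays in place of signToken() to show two properties: the waits overlap, and the results come back in input order regardless of which call finishes first, so destructuring into [accessToken, refreshToken] is always safe. Both calls still run on Node's single JavaScript thread, so this is concurrency rather than true parallelism.

```typescript
// fakeSign() is a hypothetical stand-in for this.signToken(): it resolves after
// `delayMs` to simulate an asynchronous signing call.
function fakeSign(label: string, delayMs: number): Promise<string> {
  return new Promise((resolve) => setTimeout(() => resolve(`${label}-token`), delayMs));
}

async function generateTokenPair() {
  const startedAt = Date.now();
  // Both calls start before either is awaited, so their waits overlap.
  // Promise.all resolves with results in input order, not completion order.
  const [accessToken, refreshToken] = await Promise.all([
    fakeSign('access', 300),
    fakeSign('refresh', 150), // finishes first, still ends up in position 1
  ]);
  return { accessToken, refreshToken, elapsedMs: Date.now() - startedAt };
}

generateTokenPair().then((result) => console.log(result));
```

If either promise rejects, Promise.all() rejects as soon as that happens, so a failure to sign one token surfaces as a single error from generateTokens().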
Finally, update signIn() to return this.generateTokens(user) instead of only the access token. Now, if you call the POST sign-in endpoint, you will get a response similar to the one below, containing both an accessToken and a refreshToken.
{
"accessToken": "xxxx.yyyyyy.zzz",
"refreshToken": "aaaa.bbbbbb.ccc"
}
To use the generated refresh token, we'll take a simple approach: create a new endpoint that validates the incoming refresh token and returns a new pair of access and refresh tokens.
Let's create a new DTO for the refresh token by running the following Nest CLI command:
nest g class iam/auth/dto/refresh-token.dto --no-spec --flat
In the newly created DTO class, add a refreshToken property and use the @IsNotEmpty() decorator to ensure it is not empty.
// refresh-token.dto.ts
import { IsNotEmpty } from 'class-validator';
export class RefreshTokenDto {
@IsNotEmpty()
refreshToken: string;
}
Next, add the corresponding refresh-tokens endpoint to the auth controller.
// auth.controller.ts
...
...
...
@Post('refresh-tokens')
refreshToken(@Body() refreshTokenDto: RefreshTokenDto) {
return this.authService.refreshTokens(refreshTokenDto);
}
In the auth service, we first verify the incoming refresh token. We then use its sub claim, which holds the user ID, to retrieve the user from the database and return a newly generated pair of tokens for that user.
// auth.service.ts
async refreshTokens(refreshTokenDto: RefreshTokenDto) {
  try {
    // Checks the signature, expiry, audience and issuer of the refresh token
    const { sub } = await this.jwtService.verifyAsync<
      Pick<ActiveUserData, 'sub'>
    >(refreshTokenDto.refreshToken, {
      secret: this.jwtConfiguration.secret,
      audience: this.jwtConfiguration.audience,
      issuer: this.jwtConfiguration.issuer,
    });
    // findById() takes the ID itself, not a filter object
    const user = await this.userModel.findById(sub).exec();
    if (!user) {
      throw new Error('User not found');
    }
    return await this.generateTokens(user);
  } catch (error) {
    throw new UnauthorizedException('Invalid refresh token');
  }
}
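verifyAsync() rejects the token if any of its checks fail, and the catch block maps that to a 401. The following sketch illustrates roughly what those checks are for an HS256 token signed with the secret, audience and issuer options used above. It is a simplified illustration, not the jsonwebtoken implementation: signHs256() and verifyHs256() are hypothetical helpers, the audience must be a single string, and clock tolerance is ignored.

```typescript
import { createHmac, timingSafeEqual } from 'node:crypto';

// Illustration only: simplified versions of the checks verifyAsync() performs.
// signHs256() and verifyHs256() are hypothetical helpers, not @nestjs/jwt APIs.
interface VerifyOptions {
  secret: string;
  audience: string;
  issuer: string;
}

const toBase64Url = (text: string): string => Buffer.from(text, 'utf8').toString('base64url');
const decodeJson = (part: string): any => JSON.parse(Buffer.from(part, 'base64url').toString('utf8'));
const hmacSha256 = (secret: string, input: string): Buffer =>
  createHmac('sha256', secret).update(input).digest();

function signHs256(claims: object, secret: string): string {
  const input = `${toBase64Url(JSON.stringify({ alg: 'HS256', typ: 'JWT' }))}.${toBase64Url(JSON.stringify(claims))}`;
  return `${input}.${hmacSha256(secret, input).toString('base64url')}`;
}

function verifyHs256(token: string, options: VerifyOptions, nowSeconds = Math.floor(Date.now() / 1000)) {
  const [header, payload, signature] = token.split('.');
  if (!header || !payload || !signature) throw new Error('Malformed token');
  // 1. Only accept the expected algorithm, so the header cannot change how we verify
  if (decodeJson(header).alg !== 'HS256') throw new Error('Unexpected algorithm');
  // 2. Recompute the signature and compare it in constant time
  const expected = hmacSha256(options.secret, `${header}.${payload}`);
  const actual = Buffer.from(signature, 'base64url');
  if (actual.length !== expected.length || !timingSafeEqual(actual, expected)) {
    throw new Error('Invalid signature');
  }
  // 3. Only after the signature checks out, trust the claims: expiry, audience, issuer
  const claims = decodeJson(payload);
  if (typeof claims.exp !== 'number' || nowSeconds >= claims.exp) throw new Error('Token expired');
  if (claims.aud !== options.audience) throw new Error('Invalid audience');
  if (claims.iss !== options.issuer) throw new Error('Invalid issuer');
  return claims as { sub: string; [key: string]: unknown };
}

// Usage: a refresh token issued at t=1000 with a 24-hour TTL, verified at t=2000
const options = { secret: 'dev-secret', audience: 'localhost:3000', issuer: 'localhost:3000' };
const refreshToken = signHs256(
  { sub: 'user-1', aud: options.audience, iss: options.issuer, iat: 1000, exp: 1000 + 86400 },
  options.secret,
);
console.log(verifyHs256(refreshToken, options, 2000).sub); // user-1
```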
To test this functionality, copy the refreshToken returned earlier by the sign-in endpoint and send it in the body of a POST request to the refresh-tokens endpoint. You should receive a new set of tokens.
Users can now obtain new tokens with their refresh token instead of signing in again, which lets us keep the access token TTL short and improves the app's security.
Invalidating Tokens
One benefit of JWTs is that they are stateless, which means we don't have to query a central database to validate the token on each incoming request. As long as the token's signature is valid and the token has not expired, we can authenticate the request.
But what if a token is leaked or compromised? How can we invalidate it to prevent an attacker from accessing our application?
We can address this by revoking tokens. One approach is to store tokens in the database and track which ones are valid and which have been revoked. However, this makes our tokens stateful, and the guard we implemented would have to query the database on every incoming request, which hurts application performance.
We will implement a mechanism called refresh token rotation. Whenever a refresh token is used to issue a new pair of tokens, we invalidate the original refresh token so that it cannot be used again. Only the refresh endpoint needs to check stored state, so the guard that validates access tokens stays stateless.
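The rotation logic itself does not depend on Redis. Here is a minimal in-memory sketch, assuming a Map in place of Redis and a hypothetical RotationStore class: each user has exactly one accepted refresh token ID, and presenting it swaps it for a fresh one.

```typescript
import { randomUUID } from 'node:crypto';

// Minimal in-memory sketch of refresh token rotation. A Map stands in for Redis,
// and RotationStore with its issue()/rotate() methods is hypothetical.
class RotationStore {
  // userId -> the only refresh token ID currently accepted for that user
  private readonly currentIds = new Map<string, string>();

  // Called on sign-in: store a fresh ID, replacing any previous one
  issue(userId: string): string {
    const tokenId = randomUUID();
    this.currentIds.set(userId, tokenId);
    return tokenId;
  }

  // Called on refresh: accept the presented ID once, then swap it for a new one
  rotate(userId: string, presentedId: string): string {
    if (this.currentIds.get(userId) !== presentedId) {
      throw new Error('Invalid refresh token');
    }
    this.currentIds.delete(userId); // invalidate the ID that was just used
    return this.issue(userId);
  }
}

const store = new RotationStore();
const first = store.issue('user-1');
const second = store.rotate('user-1', first);
console.log(second !== first); // true
try {
  store.rotate('user-1', first); // replaying the already-used ID
} catch (error) {
  console.log((error as Error).message); // Invalid refresh token
}
```

Because there is one entry per user, as with the user-${userId} key used later in Redis, signing in on a second device replaces the first device's refresh token ID. Supporting several sessions per user would require a key that also includes a session or token identifier.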
To store the refresh token IDs, we'll add Redis to our stack, since it is a lightweight, in-memory key-value store with fast reads and writes.
We'll add a Redis service to our docker-compose.yml file and install the ioredis package as a dependency.
# docker-compose.yml
services:
redis:
image: redis
ports:
- '6379:6379'
restart: always
npm i ioredis
Next, make sure the Docker containers are running with the following command:
docker compose up -d
Once everything is up and running, let's generate a RefreshTokenIdsStorage
class using the Nest CLI.
nest g class iam/storage/refresh-token-ids.storage --no-spec
Once the file is created, open the IamModule and register RefreshTokenIdsStorage as a provider.
// iam.module.ts
@Module({
imports: [
MongooseModule.forFeature([{ name: User.name, schema: UserSchema }]),
JwtModule.registerAsync(jwtConfig.asProvider()),
ConfigModule.forFeature(jwtConfig),
],
providers: [
{
provide: HashingService,
useClass: BcryptService,
},
{ provide: APP_GUARD, useClass: AuthGuard },
AccessTokenGuard,
RefreshTokenIdsStorage,
AuthService,
],
controllers: [AuthController],
})
export class IamModule {}
Back in the storage file, let's implement two interfaces that hook into NestJS lifecycle events: OnApplicationBootstrap and OnApplicationShutdown.
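// refresh-token-ids.storage.ts
import { Injectable, OnApplicationBootstrap, OnApplicationShutdown } from '@nestjs/common';
import Redis from 'ioredis';
@Injectable()
export class RefreshTokenIdsStorage implements OnApplicationBootstrap, OnApplicationShutdown {
  private redisClient: Redis;
  onApplicationBootstrap() {
    // Open the Redis connection when the application starts
    this.redisClient = new Redis({ host: 'localhost', port: 6379 });
  }
  onApplicationShutdown(signal?: string) {
    // Close the connection so no resources are left open
    return this.redisClient.quit();
  }
}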
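When the application bootstraps, we open the Redis connection, and when it shuts down, we close it to make sure our resources are cleaned up. Note that NestJS calls onApplicationShutdown() when app.close() is called, and on termination signals such as SIGTERM only if app.enableShutdownHooks() is enabled in main.ts.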
Next, let's add four methods to this class:
insert(): stores a new entry in Redis.
validate(): compares the token ID passed in with the stored one and returns true if they match, false otherwise.
invalidate(): invalidates the token by deleting its entry from Redis.
getKey(): a private helper that builds the Redis key from the user ID.
// refresh-token-ids.storage.ts (inside the RefreshTokenIdsStorage class)
async insert(userId: string, tokenId: string): Promise<void> {
  await this.redisClient.set(this.getKey(userId), tokenId);
}
async validate(userId: string, tokenId: string): Promise<boolean> {
  const storedTokenId = await this.redisClient.get(this.getKey(userId));
  // Only the most recently issued token ID for this user is valid
  return storedTokenId === tokenId;
}
async invalidate(userId: string): Promise<void> {
  await this.redisClient.del(this.getKey(userId));
}
// MongoDB user IDs are strings, so the key is built from a string ID
private getKey(userId: string): string {
  return `user-${userId}`;
}
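As written, insert() stores the token ID without an expiry, so the entry stays in Redis even after the refresh token itself has expired. Redis can expire keys automatically, and with ioredis you could pass an expiry to set(), for example `set(key, tokenId, 'EX', ttlSeconds)`. The sketch below is a hypothetical in-memory stand-in (ExpiringTokenIdStore, with an injectable clock for testing) that mimics that behaviour so the logic can be tried without a Redis server.

```typescript
// In-memory stand-in for RefreshTokenIdsStorage whose entries expire, like Redis keys
// written with `SET key value EX seconds`. With ioredis the equivalent insert would be:
//   await this.redisClient.set(this.getKey(userId), tokenId, 'EX', ttlSeconds);
// ExpiringTokenIdStore and its injectable clock are hypothetical, for illustration only.
class ExpiringTokenIdStore {
  private readonly entries = new Map<string, { tokenId: string; expiresAt: number }>();
  private readonly now: () => number;

  constructor(now: () => number = () => Date.now()) {
    this.now = now;
  }

  async insert(userId: string, tokenId: string, ttlSeconds: number): Promise<void> {
    this.entries.set(this.getKey(userId), { tokenId, expiresAt: this.now() + ttlSeconds * 1000 });
  }

  async validate(userId: string, tokenId: string): Promise<boolean> {
    const entry = this.entries.get(this.getKey(userId));
    // A missing or expired entry is treated like a deleted Redis key
    if (!entry || this.now() >= entry.expiresAt) {
      return false;
    }
    return entry.tokenId === tokenId;
  }

  async invalidate(userId: string): Promise<void> {
    this.entries.delete(this.getKey(userId));
  }

  private getKey(userId: string): string {
    return `user-${userId}`;
  }
}
```

In generateTokens(), the natural TTL to pass would be this.jwtConfiguration.refreshTokenTtl, so the stored ID disappears at about the same time the refresh token itself expires.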
The next step is to inject the new storage service into our authentication service constructor.
//auth.service.ts
constructor(
  ...
  ...
  private readonly refreshTokenIdsStorage: RefreshTokenIdsStorage,
) {}
Now, before signing the tokens, let's generate a unique refresh token ID using the randomUUID() function from Node's built-in crypto module.
We pass this refreshTokenId as the payload to signToken() when signing the refreshToken, and then store the newly generated refreshTokenId in our Redis database.
//auth.service.ts
import { randomUUID } from 'crypto';
...
async generateTokens(user: User) {
  // Unique ID that identifies this particular refresh token
  const refreshTokenId = randomUUID();
  const [accessToken, refreshToken] = await Promise.all([
    this.signToken<Partial<ActiveUserData>>(
      user.id,
      this.jwtConfiguration.accessTokenTtl,
      { email: user.email },
    ),
    this.signToken(user.id, this.jwtConfiguration.refreshTokenTtl, { refreshTokenId }),
  ]);
  // Store the newly generated `refreshTokenId` in Redis; awaiting it ensures the ID
  // is persisted before the tokens are returned to the client
  await this.refreshTokenIdsStorage.insert(user.id, refreshTokenId);
  return { accessToken, refreshToken };
}
With the generateTokens() method complete, let's refactor the refreshTokens() method to use this functionality.
//auth.service.ts
async refreshTokens(refreshTokenDto: RefreshTokenDto) {
  try {
    // The payload now also carries the refreshTokenId claim
    const { sub, refreshTokenId } = await this.jwtService.verifyAsync<
      Pick<ActiveUserData, 'sub'> & { refreshTokenId: string }
    >(refreshTokenDto.refreshToken, {
      secret: this.jwtConfiguration.secret,
      audience: this.jwtConfiguration.audience,
      issuer: this.jwtConfiguration.issuer,
    });
    const user = await this.userModel.findById(sub).exec();
    if (!user) {
      throw new Error('User not found');
    }
    const isValid = await this.refreshTokenIdsStorage.validate(
      user.id,
      refreshTokenId,
    );
    if (isValid) {
      // Rotation: the used refresh token ID can never be used again
      await this.refreshTokenIdsStorage.invalidate(user.id);
    } else {
      throw new Error('Invalid refresh token');
    }
    return await this.generateTokens(user);
  } catch (error) {
    throw new UnauthorizedException('Invalid refresh token');
  }
}
In the above code, we first extend the type argument of verifyAsync() so the payload includes refreshTokenId. We then check that ID against the one stored in Redis using the validate() method of the RefreshTokenIdsStorage class.
If the token ID is valid, we invalidate it so that it cannot be used in any subsequent request. If it is invalid, we throw an error, which the catch block turns into a 401 Unauthorized response.
This technique is called refresh token rotation, and we have now implemented it. Let's test it.
Start by calling the SignIn API endpoint to get a fresh accessToken and refreshToken. If there are no users in the database yet, use the SignUp API endpoint to create one first.
Response of the SignIn API:
Now send that refreshToken to the refresh-tokens API endpoint, which in our case is localhost:3000/auth/refresh-tokens. The first response from this endpoint will look something like this.
As you can see, using the previous refreshToken we get a new accessToken and refreshToken.
Now call the endpoint again with the same refreshToken value in the JSON body. This time, the response is not a new pair of tokens as in the previous image, but an error like the one shown below.
The above image shows that our refreshToken has been invalidated, so our approach of invalidating a refresh token after its first use works as expected.
Always keep your accessToken short-lived: the shorter its lifetime, the less time an attacker has to use a leaked token. Refresh token rotation does not revoke access tokens that were already issued; they remain valid until they expire. In our case, the access token lifetime is 1 hour (3600 seconds), as set by JWT_ACCESS_TOKEN_TTL in the .env file.
Conclusion
We covered a lot, so let's summarize:
Implementing refresh tokens
Storing the newly generated refresh token IDs in a key-value store such as Redis
Invalidating a refresh token after its first use (refresh token rotation)
That's it for today. Next, we will work on the Authorization module.
Next Post
Stay tuned for the next post, which will be published by 14 July 2024.
References
https://blog.amanpreet.dev/how-to-implement-role-based-access-control-in-nestjs-with-mongodb-part-1
https://blog.amanpreet.dev/how-to-implement-role-based-access-control-in-nestjs-with-mongodb-part-2
https://blog.amanpreet.dev/how-to-implement-role-based-access-control-in-nestjs-with-mongodb-part-3
https://blog.amanpreet.dev/how-to-implement-role-based-access-control-in-nestjs-with-mongodb-part-4