Improve the caching logic

master
Rajitha Kumara 1 year ago
parent 393b919342
commit dfb6fe0de3

@ -1,7 +1,7 @@
package io.entgra.token.validator.cache;
import io.entgra.token.validator.dto.TokenDTO;
import io.entgra.token.validator.entities.Token;
import io.entgra.token.validator.exception.TokenCacheException;
import java.util.HashMap;
@ -9,7 +9,7 @@ public class TokenCache {
private static volatile TokenCache instance;
private final HashMap<Token, Token> cache = new HashMap<>();
private final HashMap<String, Token> cache = new HashMap<>();
TokenCache() {}
@ -24,21 +24,23 @@ public class TokenCache {
return instance;
}
public void add(Token token) {
cache.put(token, token);
}
public void add(Token token) throws TokenCacheException {
if (token == null) throw new TokenCacheException("Can not cache null token");
if (token.getAccessToken() != null)
cache.put(token.getAccessToken(), token);
public Token get(TokenDTO tokenDTO) {
return cache.get(buildCacheKey(tokenDTO));
if (token.getRefreshToken() != null)
cache.put(token.getRefreshToken(), token);
}
private Token buildCacheKey(TokenDTO tokenDTO) {
Token key = new Token();
if (tokenDTO != null) {
key.setAccessToken(tokenDTO.getAccessToken());
key.setRefreshToken(tokenDTO.getRefreshToken());
}
return key;
public Token get(Token token) throws TokenCacheException {
if (token == null) throw new TokenCacheException("Error get cached value for a null token");
if (token.getAccessToken() != null)
return cache.get(token.getAccessToken());
return cache.get(token.getRefreshToken());
}
}

@ -4,10 +4,12 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import javax.persistence.*;
import javax.persistence.Column;
import javax.persistence.Entity;
import javax.persistence.Id;
import javax.persistence.Table;
import java.io.Serializable;
import java.util.Date;
import java.util.Objects;
@NoArgsConstructor
@Getter
@ -15,9 +17,6 @@ import java.util.Objects;
@Entity
@Table(name = "IDN_OAUTH2_ACCESS_TOKEN")
public class Token implements Serializable {
@Transient
private volatile int hashCode;
@Id
@Column(name = "TOKEN_ID")
private String tokenId;
@ -87,24 +86,4 @@ public class Token implements Serializable {
@Column(name = "REFRESH_TOKEN_HASH")
private String refreshTokenHash;
@Override
public boolean equals(Object that) {
if (that == null) return false;
if (that instanceof Token) {
Token thatToken = (Token) that;
return Objects.equals(thatToken.accessToken, this.accessToken)
&& Objects.equals(thatToken.refreshToken, this.refreshToken);
}
return false;
}
@Override
public int hashCode() {
if (hashCode == 0) {
hashCode = Objects.hash(accessToken, refreshToken);
}
return hashCode;
}
}

@ -0,0 +1,11 @@
package io.entgra.token.validator.exception;
public class TokenCacheException extends Exception {
public TokenCacheException(String msg, Throwable t) {
super(msg, t);
}
public TokenCacheException(String msg) {
super(msg);
}
}

@ -6,6 +6,7 @@ import io.entgra.token.validator.dto.AccessTokenDTO;
import io.entgra.token.validator.dto.TokenDTO;
import io.entgra.token.validator.dto.ValidationInfoDTO;
import io.entgra.token.validator.entities.Token;
import io.entgra.token.validator.exception.TokenCacheException;
import io.entgra.token.validator.exception.ValidationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -21,13 +22,16 @@ public class TokenService {
public TokenDAO tokenDAO;
public ValidationInfoDTO validateToken(AccessTokenDTO accessTokenDTO) throws ValidationException {
Token retrievedToken = tokenDAO.findByAccessToken(accessTokenDTO.getAccessToken());
Token cachedToken = getFromCache(mapToToken(accessTokenDTO));
Token retrievedToken = cachedToken != null ? cachedToken:
tokenDAO.findByAccessToken(accessTokenDTO.getAccessToken());
if (retrievedToken == null) {
String msg = "Can't validate the access token";
log.error(msg);
throw new ValidationException(msg);
}
TokenCache.getInstance().add(retrievedToken);
if (cachedToken == null) addToCache(retrievedToken);
String tenantDomain = extractTenantDomain(retrievedToken.getSubjectIdentifier(),
retrievedToken.getTenantId());
log.info("Validate the access token owns by " + retrievedToken.getAuthzUser() + "@" + tenantDomain);
@ -37,7 +41,7 @@ public class TokenService {
}
public TokenDTO renew(TokenDTO tokenDTO) throws ValidationException {
Token cachedToken = TokenCache.getInstance().get(tokenDTO);
Token cachedToken = getFromCache(mapToToken(tokenDTO));
Token retrievedToken = cachedToken != null ? cachedToken :
tokenDAO.findByRefreshToken(tokenDTO.getRefreshToken());
if (retrievedToken == null) {
@ -45,6 +49,8 @@ public class TokenService {
log.error(msg);
throw new ValidationException(msg);
}
if (cachedToken == null) addToCache(retrievedToken);
int expiresIn = retrievedToken.getValidityPeriod() > 0 ?
(int) retrievedToken.getValidityPeriod() / 1000 : 0;
log.info("Successfully renewed the token");
@ -60,4 +66,44 @@ public class TokenService {
return subjectIdentifier.split("@")[1];
}
private Token mapToToken(AccessTokenDTO accessTokenDTO) {
if (accessTokenDTO == null) return null;
Token token = new Token();
if (accessTokenDTO.getAccessToken() != null)
token.setAccessToken(accessTokenDTO.getAccessToken());
return token;
}
private Token mapToToken(TokenDTO tokenDTO) {
if (tokenDTO == null) return null;
Token token = new Token();
if (tokenDTO.getAccessToken() != null)
token.setAccessToken(tokenDTO.getAccessToken());
if (tokenDTO.getRefreshToken() != null)
token.setRefreshToken(tokenDTO.getRefreshToken());
return token;
}
private void addToCache(Token token) throws ValidationException{
try {
TokenCache.getInstance().add(token);
} catch (TokenCacheException e) {
String msg = "Error occurred while caching";
throw new ValidationException(msg, e);
}
}
private Token getFromCache(Token token) throws ValidationException {
try {
return TokenCache.getInstance().get(token);
} catch (TokenCacheException e) {
String msg = "Error occurred while getting the cached token";
throw new ValidationException(msg, e);
}
}
}

Loading…
Cancel
Save