Implement state checking and csrf protection

## Purpose

* Fixes for http://roadmap.entgra.net/issues/9846

Co-authored-by: rajitha <rajitha@entgra.io>
Reviewed-on: community/device-mgt-core#232
Co-authored-by: Rajitha Kumara <rajitha@entgra.io>
Co-committed-by: Rajitha Kumara <rajitha@entgra.io>
master
Rajitha Kumara 1 year ago committed by Lasantha Dharmakeerthi
parent 327f507aa8
commit 69efff10bd

@ -27,6 +27,7 @@ import io.entgra.device.mgt.core.ui.request.interceptor.util.HandlerUtil;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.http.HttpHeaders; import org.apache.http.HttpHeaders;
import org.apache.http.HttpStatus;
import org.apache.http.client.methods.HttpPost; import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ContentType; import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity; import org.apache.http.entity.StringEntity;
@ -39,6 +40,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSession;
import java.io.IOException; import java.io.IOException;
import java.util.Objects;
@MultipartConfig @MultipartConfig
@WebServlet("/ssoLoginCallback") @WebServlet("/ssoLoginCallback")
@ -47,6 +49,7 @@ public class SsoLoginCallbackHandler extends HttpServlet {
@Override @Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException { protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
String state = req.getParameter("state");
String code = req.getParameter("code"); String code = req.getParameter("code");
HttpSession session = req.getSession(false); HttpSession session = req.getSession(false);
@ -66,6 +69,11 @@ public class SsoLoginCallbackHandler extends HttpServlet {
return; return;
} }
if (state == null || !Objects.equals(state, session.getAttribute("state").toString())) {
resp.sendError(HttpStatus.SC_BAD_REQUEST, "MismatchingStateError: CSRF Warning! State not equal in request and response");
return;
}
String scope = session.getAttribute("scope").toString(); String scope = session.getAttribute("scope").toString();
HttpPost tokenEndpoint = new HttpPost(keyManagerUrl + HandlerConstants.OAUTH2_TOKEN_ENDPOINT); HttpPost tokenEndpoint = new HttpPost(keyManagerUrl + HandlerConstants.OAUTH2_TOKEN_ENDPOINT);
@ -75,7 +83,7 @@ public class SsoLoginCallbackHandler extends HttpServlet {
String loginCallbackUrl = iotsCoreUrl + req.getContextPath() + HandlerConstants.SSO_LOGIN_CALLBACK; String loginCallbackUrl = iotsCoreUrl + req.getContextPath() + HandlerConstants.SSO_LOGIN_CALLBACK;
StringEntity tokenEPPayload = new StringEntity( StringEntity tokenEPPayload = new StringEntity(
"grant_type=" + HandlerConstants.CODE_GRANT_TYPE + "&code=" + code + "&state=&scope=" + scope + "grant_type=" + HandlerConstants.CODE_GRANT_TYPE + "&code=" + code + "&scope=" + scope +
"&redirect_uri=" + loginCallbackUrl, "&redirect_uri=" + loginCallbackUrl,
ContentType.APPLICATION_FORM_URLENCODED); ContentType.APPLICATION_FORM_URLENCODED);
tokenEndpoint.setEntity(tokenEPPayload); tokenEndpoint.setEntity(tokenEPPayload);

@ -86,6 +86,7 @@ public class SsoLoginHandler extends HttpServlet {
private LoginCache loginCache; private LoginCache loginCache;
private OAuthApp oAuthApp; private OAuthApp oAuthApp;
private OAuthAppCacheKey oAuthAppCacheKey; private OAuthAppCacheKey oAuthAppCacheKey;
private String state;
@Override @Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) { protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
@ -97,6 +98,7 @@ public class SsoLoginHandler extends HttpServlet {
httpSession = req.getSession(true); httpSession = req.getSession(true);
state = HandlerUtil.generateStateToken();
initializeAdminCredentials(); initializeAdminCredentials();
baseContextPath = req.getContextPath(); baseContextPath = req.getContextPath();
applicationName = baseContextPath.substring(1, baseContextPath.indexOf("-ui-request-handler")); applicationName = baseContextPath.substring(1, baseContextPath.indexOf("-ui-request-handler"));
@ -127,12 +129,11 @@ public class SsoLoginHandler extends HttpServlet {
String scopesSsoString = HandlerUtil.getScopeString(scopesSsoJson); String scopesSsoString = HandlerUtil.getScopeString(scopesSsoJson);
String loginCallbackUrl = iotsCoreUrl + baseContextPath + HandlerConstants.SSO_LOGIN_CALLBACK; String loginCallbackUrl = iotsCoreUrl + baseContextPath + HandlerConstants.SSO_LOGIN_CALLBACK;
persistAuthSessionData(req, oAuthApp.getClientId(), oAuthApp.getClientSecret(), persistAuthSessionData(req, oAuthApp.getClientId(), oAuthApp.getClientSecret(),
oAuthApp.getEncodedClientApp(), scopesSsoString); oAuthApp.getEncodedClientApp(), scopesSsoString, state);
resp.sendRedirect(keyManagerUrl + HandlerConstants.AUTHORIZATION_ENDPOINT + resp.sendRedirect(keyManagerUrl + HandlerConstants.AUTHORIZATION_ENDPOINT +
"?response_type=code" + "?response_type=code" +
"&state=" + state +
"&client_id=" + clientId + "&client_id=" + clientId +
"&state=" +
"&scope=openid " + scopesSsoString + "&scope=openid " + scopesSsoString +
"&redirect_uri=" + loginCallbackUrl); "&redirect_uri=" + loginCallbackUrl);
} catch (IOException e) { } catch (IOException e) {
@ -186,7 +187,7 @@ public class SsoLoginHandler extends HttpServlet {
clientSecret = jClientAppResultAsJsonObject.get("client_secret").getAsString(); clientSecret = jClientAppResultAsJsonObject.get("client_secret").getAsString();
encodedClientApp = Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes()); encodedClientApp = Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes());
String scopesString = HandlerUtil.getScopeString(scopes); String scopesString = HandlerUtil.getScopeString(scopes);
persistAuthSessionData(req, clientId, clientSecret, encodedClientApp, scopesString); persistAuthSessionData(req, clientId, clientSecret, encodedClientApp, scopesString, state);
} }
// cache the oauth app credentials // cache the oauth app credentials
@ -287,13 +288,14 @@ public class SsoLoginHandler extends HttpServlet {
* @param scopes - User scopes * @param scopes - User scopes
*/ */
private void persistAuthSessionData(HttpServletRequest req, String clientId, String clientSecret, private void persistAuthSessionData(HttpServletRequest req, String clientId, String clientSecret,
String encodedClientApp, String scopes) { String encodedClientApp, String scopes, String state) {
httpSession = req.getSession(false); httpSession = req.getSession(false);
httpSession.setAttribute("clientId", clientId); httpSession.setAttribute("clientId", clientId);
httpSession.setAttribute("clientSecret", clientSecret); httpSession.setAttribute("clientSecret", clientSecret);
httpSession.setAttribute("encodedClientApp", encodedClientApp); httpSession.setAttribute("encodedClientApp", encodedClientApp);
httpSession.setAttribute("scope", scopes); httpSession.setAttribute("scope", scopes);
httpSession.setAttribute("redirectUrl", req.getParameter("redirect")); httpSession.setAttribute("redirectUrl", req.getParameter("redirect"));
httpSession.setAttribute("state", state);
httpSession.setMaxInactiveInterval(sessionTimeOut); httpSession.setMaxInactiveInterval(sessionTimeOut);
} }

@ -71,6 +71,8 @@ import java.io.IOException;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.io.StringWriter; import java.io.StringWriter;
import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.List; import java.util.List;
@ -763,4 +765,8 @@ public class HandlerUtil {
} }
return otpManagementService; return otpManagementService;
} }
public static String generateStateToken() {
return new BigInteger(130, new SecureRandom()).toString(32);
}
} }

Loading…
Cancel
Save