Improve error handling for token endpoint

appcategoryfix
Vigneshan Seshamany 2 years ago
parent b14c9c7ad6
commit dd553a2326

@ -18,6 +18,7 @@
package org.wso2.carbon.apimgt.keymgt.extension.api; package org.wso2.carbon.apimgt.keymgt.extension.api;
import com.google.gson.Gson;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.wso2.carbon.apimgt.keymgt.extension.DCRResponse; import org.wso2.carbon.apimgt.keymgt.extension.DCRResponse;
@ -41,6 +42,8 @@ import java.util.Base64;
public class KeyManagerServiceImpl implements KeyManagerService { public class KeyManagerServiceImpl implements KeyManagerService {
Gson gson = new Gson();
@Override @Override
@POST @POST
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@ -51,7 +54,7 @@ public class KeyManagerServiceImpl implements KeyManagerService {
KeyMgtService keyMgtService = new KeyMgtServiceImpl(); KeyMgtService keyMgtService = new KeyMgtServiceImpl();
DCRResponse resp = keyMgtService.dynamicClientRegistration(dcrRequest.getApplicationName(), dcrRequest.getUsername(), DCRResponse resp = keyMgtService.dynamicClientRegistration(dcrRequest.getApplicationName(), dcrRequest.getUsername(),
dcrRequest.getGrantTypes(), dcrRequest.getCallBackUrl(), dcrRequest.getTags(), dcrRequest.getIsSaasApp()); dcrRequest.getGrantTypes(), dcrRequest.getCallBackUrl(), dcrRequest.getTags(), dcrRequest.getIsSaasApp());
return Response.status(Response.Status.CREATED).entity(resp).build(); return Response.status(Response.Status.CREATED).entity(gson.toJson(resp)).build();
} catch (KeyMgtException e) { } catch (KeyMgtException e) {
return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(e.getMessage()).build(); return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(e.getMessage()).build();
} }
@ -80,7 +83,7 @@ public class KeyManagerServiceImpl implements KeyManagerService {
new TokenRequest(encodedClientCredentials.split(":")[0], new TokenRequest(encodedClientCredentials.split(":")[0],
encodedClientCredentials.split(":")[1], refreshToken, scope, encodedClientCredentials.split(":")[1], refreshToken, scope,
grantType, assertion,admin_access_token)); grantType, assertion,admin_access_token));
return Response.status(Response.Status.CREATED).entity(resp).build(); return Response.status(Response.Status.CREATED).entity(gson.toJson(resp)).build();
} catch (KeyMgtException e) { } catch (KeyMgtException e) {
return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(e.getMessage()).build(); return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(e.getMessage()).build();
} catch (BadRequestException e) { } catch (BadRequestException e) {

@ -28,11 +28,9 @@ import okhttp3.RequestBody;
import okhttp3.Response; import okhttp3.Response;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.http.HttpStatus;
import org.json.JSONObject; import org.json.JSONObject;
import org.wso2.carbon.apimgt.api.APIConsumer; import org.wso2.carbon.apimgt.api.APIConsumer;
import org.wso2.carbon.apimgt.api.APIManagementException; import org.wso2.carbon.apimgt.api.APIManagementException;
import org.wso2.carbon.apimgt.api.model.APIKey;
import org.wso2.carbon.apimgt.api.model.Application; import org.wso2.carbon.apimgt.api.model.Application;
import org.wso2.carbon.apimgt.impl.APIManagerFactory; import org.wso2.carbon.apimgt.impl.APIManagerFactory;
import org.wso2.carbon.apimgt.impl.utils.APIUtil; import org.wso2.carbon.apimgt.impl.utils.APIUtil;
@ -42,6 +40,7 @@ import org.wso2.carbon.apimgt.keymgt.extension.KeyMgtConstants;
import org.wso2.carbon.apimgt.keymgt.extension.OAuthApplication; import org.wso2.carbon.apimgt.keymgt.extension.OAuthApplication;
import org.wso2.carbon.apimgt.keymgt.extension.TokenRequest; import org.wso2.carbon.apimgt.keymgt.extension.TokenRequest;
import org.wso2.carbon.apimgt.keymgt.extension.TokenResponse; import org.wso2.carbon.apimgt.keymgt.extension.TokenResponse;
import org.wso2.carbon.apimgt.keymgt.extension.exception.BadRequestException;
import org.wso2.carbon.apimgt.keymgt.extension.exception.KeyMgtException; import org.wso2.carbon.apimgt.keymgt.extension.exception.KeyMgtException;
import org.wso2.carbon.context.PrivilegedCarbonContext; import org.wso2.carbon.context.PrivilegedCarbonContext;
import org.wso2.carbon.device.mgt.core.config.DeviceConfigurationManager; import org.wso2.carbon.device.mgt.core.config.DeviceConfigurationManager;
@ -64,7 +63,7 @@ import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Objects;
public class KeyMgtServiceImpl implements KeyMgtService { public class KeyMgtServiceImpl implements KeyMgtService {
@ -148,9 +147,17 @@ public class KeyMgtServiceImpl implements KeyMgtService {
} }
} }
public TokenResponse generateAccessToken(TokenRequest tokenRequest) throws KeyMgtException { public TokenResponse generateAccessToken(TokenRequest tokenRequest) throws KeyMgtException, BadRequestException {
try { try {
Application application = APIUtil.getApplicationByClientId(tokenRequest.getClientId()); Application application = APIUtil.getApplicationByClientId(tokenRequest.getClientId());
if (application == null) {
JSONObject errorResponse = new JSONObject();
errorResponse.put("error", "invalid_client");
errorResponse.put("error_description", "A valid OAuth client could not be found for client_id: "
+ tokenRequest.getClientId());
throw new BadRequestException(errorResponse.toString());
}
String tenantDomain = MultitenantUtils.getTenantDomain(application.getOwner()); String tenantDomain = MultitenantUtils.getTenantDomain(application.getOwner());
String username, password; String username, password;
@ -173,7 +180,6 @@ public class KeyMgtServiceImpl implements KeyMgtService {
} }
} }
JSONObject jsonObject = new JSONObject();
RequestBody appTokenPayload; RequestBody appTokenPayload;
switch (tokenRequest.getGrantType()) { switch (tokenRequest.getGrantType()) {
case "client_credentials": case "client_credentials":
@ -209,7 +215,6 @@ public class KeyMgtServiceImpl implements KeyMgtService {
.add("scope", tokenRequest.getScope()).build(); .add("scope", tokenRequest.getScope()).build();
break; break;
} }
jsonObject.put("scope", tokenRequest.getScope());
kmConfig = getKeyManagerConfig(); kmConfig = getKeyManagerConfig();
String appTokenEndpoint = kmConfig.getServerUrl() + KeyMgtConstants.OAUTH2_TOKEN_ENDPOINT; String appTokenEndpoint = kmConfig.getServerUrl() + KeyMgtConstants.OAUTH2_TOKEN_ENDPOINT;
@ -220,20 +225,25 @@ public class KeyMgtServiceImpl implements KeyMgtService {
.build(); .build();
Response response = client.newCall(request).execute(); Response response = client.newCall(request).execute();
jsonObject = new JSONObject(response.body().string()); JSONObject responseObj = new JSONObject(Objects.requireNonNull(response.body()).string());
if (!response.isSuccessful()) {
throw new BadRequestException(responseObj.toString());
}
String accessToken; String accessToken;
if (KeyMgtConstants.SUPER_TENANT.equals(tenantDomain)) { if (KeyMgtConstants.SUPER_TENANT.equals(tenantDomain)) {
accessToken = jsonObject.getString("access_token"); accessToken = responseObj.getString("access_token");
} else { } else {
int tenantId = getRealmService() int tenantId = getRealmService()
.getTenantManager().getTenantId(tenantDomain); .getTenantManager().getTenantId(tenantDomain);
accessToken = tenantId + "_" + jsonObject.getString("access_token"); accessToken = tenantId + "_" + responseObj.getString("access_token");
} }
return new TokenResponse(accessToken, return new TokenResponse(accessToken,
jsonObject.getString("refresh_token"), responseObj.getString("refresh_token"),
jsonObject.getString("scope"), responseObj.getString("scope"),
jsonObject.getString("token_type"), responseObj.getString("token_type"),
jsonObject.getInt("expires_in")); responseObj.getInt("expires_in"));
} catch (APIManagementException e) { } catch (APIManagementException e) {
msg = "Error occurred while retrieving application"; msg = "Error occurred while retrieving application";

Loading…
Cancel
Save