diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/GRPCAuthInterceptor.java b/sdk/src/main/java/io/opentdf/platform/sdk/GRPCAuthInterceptor.java index 5a796489..a7e4e8b2 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/GRPCAuthInterceptor.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/GRPCAuthInterceptor.java @@ -5,7 +5,6 @@ import com.nimbusds.jose.jwk.RSAKey; import com.nimbusds.jwt.SignedJWT; import com.nimbusds.oauth2.sdk.AuthorizationGrant; -import com.nimbusds.oauth2.sdk.ClientCredentialsGrant; import com.nimbusds.oauth2.sdk.ErrorObject; import com.nimbusds.oauth2.sdk.TokenRequest; import com.nimbusds.oauth2.sdk.TokenResponse; @@ -41,6 +40,7 @@ class GRPCAuthInterceptor implements ClientInterceptor { private final ClientAuthentication clientAuth; private final RSAKey rsaKey; private final URI tokenEndpointURI; + private final AuthorizationGrant authzGrant; private SSLFactory sslFactory; private static final Logger logger = LoggerFactory.getLogger(GRPCAuthInterceptor.class); @@ -52,11 +52,12 @@ class GRPCAuthInterceptor implements ClientInterceptor { * @param rsaKey the RSA key to be used by the interceptor * @param sslFactory Optional SSLFactory for Requests */ - public GRPCAuthInterceptor(ClientAuthentication clientAuth, RSAKey rsaKey, URI tokenEndpointURI, SSLFactory sslFactory) { + public GRPCAuthInterceptor(ClientAuthentication clientAuth, RSAKey rsaKey, URI tokenEndpointURI, AuthorizationGrant authzGrant, SSLFactory sslFactory) { this.clientAuth = clientAuth; this.rsaKey = rsaKey; this.tokenEndpointURI = tokenEndpointURI; this.sslFactory = sslFactory; + this.authzGrant = authzGrant; } /** @@ -110,12 +111,9 @@ private synchronized AccessToken getToken() { logger.trace("The current access token is expired or empty, getting a new one"); - // Construct the client credentials grant - AuthorizationGrant clientGrant = new ClientCredentialsGrant(); - // Make the token request TokenRequest tokenRequest = new TokenRequest(this.tokenEndpointURI, - clientAuth, clientGrant, null); + clientAuth, authzGrant, null); HTTPRequest httpRequest = tokenRequest.toHTTPRequest(); if(sslFactory!=null){ httpRequest.setSSLSocketFactory(sslFactory.getSslSocketFactory()); diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java index be13bdd0..4068b5e0 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java @@ -4,12 +4,17 @@ import com.nimbusds.jose.jwk.KeyUse; import com.nimbusds.jose.jwk.RSAKey; import com.nimbusds.jose.jwk.gen.RSAKeyGenerator; +import com.nimbusds.oauth2.sdk.AuthorizationGrant; +import com.nimbusds.oauth2.sdk.ClientCredentialsGrant; import com.nimbusds.oauth2.sdk.GeneralException; import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; import com.nimbusds.oauth2.sdk.auth.Secret; import com.nimbusds.oauth2.sdk.id.ClientID; import com.nimbusds.oauth2.sdk.id.Issuer; +import com.nimbusds.oauth2.sdk.token.BearerAccessToken; +import com.nimbusds.oauth2.sdk.token.TokenTypeURI; +import com.nimbusds.oauth2.sdk.tokenexchange.TokenExchangeGrant; import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; import io.grpc.*; import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest; @@ -41,6 +46,7 @@ public class SDKBuilder { private ClientAuthentication clientAuth = null; private Boolean usePlainText; private SSLFactory sslFactory; + private AuthorizationGrant authzGrant; private static final Logger logger = LoggerFactory.getLogger(SDKBuilder.class); @@ -49,6 +55,7 @@ public static SDKBuilder newBuilder() { builder.usePlainText = false; builder.clientAuth = null; builder.platformEndpoint = null; + builder.authzGrant = null; return builder; } @@ -99,6 +106,24 @@ public SDKBuilder platformEndpoint(String platformEndpoint) { return this; } + public SDKBuilder authorizationGrant(AuthorizationGrant authzGrant) { + if (this.authzGrant != null) { + throw new RuntimeException("Authorization grant can't be specified twice"); + } + this.authzGrant = authzGrant; + return this; + } + + public SDKBuilder tokenExchange(String jwt) { + if (this.authzGrant != null) { + throw new RuntimeException("Authorization grant can't be specified twice"); + } + + BearerAccessToken token = new BearerAccessToken(jwt); + this.authzGrant = new TokenExchangeGrant(token, TokenTypeURI.ACCESS_TOKEN); + return this; + } + public SDKBuilder clientSecret(String clientID, String clientSecret) { ClientID cid = new ClientID(clientID); Secret cs = new Secret(clientSecret); @@ -168,7 +193,11 @@ private GRPCAuthInterceptor getGrpcAuthInterceptor(RSAKey rsaKey) { throw new SDKException("Error resolving the OIDC provider metadata", e); } - return new GRPCAuthInterceptor(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI(), sslFactory); + if (this.authzGrant == null) { + this.authzGrant = new ClientCredentialsGrant(); + } + + return new GRPCAuthInterceptor(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI(), this.authzGrant, sslFactory); } static class ServicesAndInternals {