Skip to content

Commit

Permalink
implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmehta19 committed Nov 1, 2024
1 parent b456935 commit 3510643
Show file tree
Hide file tree
Showing 11 changed files with 285 additions and 31 deletions.
2 changes: 2 additions & 0 deletions gapic-generator-java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -397,10 +397,12 @@
<dependency>
<groupId>com.google.api</groupId>
<artifactId>gax-grpc</artifactId>
<version>2.57.1-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>com.google.api</groupId>
<artifactId>gax-grpc</artifactId>
<version>2.57.1-SNAPSHOT</version>
<!-- import the test code, https://maven.apache.org/plugins/maven-jar-plugin/examples/create-test-jar.html -->
<type>test-jar</type>
<classifier>testlib</classifier>
Expand Down
1 change: 1 addition & 0 deletions gax-java/gax-grpc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ _COMPILE_DEPS = [
"@io_grpc_grpc_netty_shaded//jar",
"@io_grpc_grpc_grpclb//jar",
"@io_grpc_grpc_java//alts:alts",
"@io_grpc_grpc_java//s2a:s2a",
"@io_netty_netty_tcnative_boringssl_static//jar",
"@javax_annotation_javax_annotation_api//jar",
"//gax:gax",
Expand Down
4 changes: 4 additions & 0 deletions gax-java/gax-grpc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@
<groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-s2a</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-stub</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import com.google.auth.ApiKeyCredentials;
import com.google.auth.Credentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.auth.oauth2.S2A;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand All @@ -54,11 +55,13 @@
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.TlsChannelCredentials;
import io.grpc.alts.GoogleDefaultChannelCredentials;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.s2a.S2AChannelCredentials;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -99,6 +102,12 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@VisibleForTesting
static final String DIRECT_PATH_ENV_ENABLE_XDS = "GOOGLE_CLOUD_ENABLE_DIRECT_PATH_XDS";

private static final String S2A_ENV_ENABLE_USE_S2A = "EXPERIMENTAL_GOOGLE_API_USE_S2A";
private static final String MTLS_MDS_ROOT = "/run/google-mds-mtls/root.crt";
// The mTLS MDS credentials are formatted as the concatenation of a PEM-encoded certificate chain
// followed by a PEM-encoded private key.
private static final String MTLS_MDS_CERT_CHAIN_AND_KEY = "/run/google-mds-mtls/client.key";

static final long DIRECT_PATH_KEEP_ALIVE_TIME_SECONDS = 3600;
static final long DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS = 20;
static final String GCE_PRODUCTION_NAME_PRIOR_2016 = "Google";
Expand All @@ -108,6 +117,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
private final Executor executor;
private final HeaderProvider headerProvider;
private final String endpoint;
private final String mtlsEndpoint;
// TODO: remove. envProvider currently provides DirectPath environment variable, and is only used
// during initial rollout for DirectPath. This provider will be removed once the DirectPath
// environment is not used.
Expand Down Expand Up @@ -136,6 +146,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.executor = builder.executor;
this.headerProvider = builder.headerProvider;
this.endpoint = builder.endpoint;
this.mtlsEndpoint = builder.mtlsEndpoint;
this.mtlsProvider = builder.mtlsProvider;
this.envProvider = builder.envProvider;
this.interceptorProvider = builder.interceptorProvider;
Expand Down Expand Up @@ -211,6 +222,10 @@ public boolean needsEndpoint() {
return endpoint == null;
}

public boolean needsMtlsEndpoint() {
return mtlsEndpoint == null;
}

/**
* Specify the endpoint the channel should connect to.
*
Expand All @@ -225,6 +240,20 @@ public TransportChannelProvider withEndpoint(String endpoint) {
return toBuilder().setEndpoint(endpoint).build();
}

/**
* Specify the MTLS endpoint.
*
* <p>The value of {@code mtlsEndpoint} must be of the form {@code host:port}.
*
* @param mtlsEndpoint
* @return A new {@link InstantiatingGrpcChannelProvider} with the specified MTLS endpoint
* configured
*/
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
validateEndpoint(mtlsEndpoint);
return toBuilder().setMtlsEndpoint(mtlsEndpoint).build();
}

/** @deprecated Please modify pool settings via {@link #toBuilder()} */
@Deprecated
@Override
Expand Down Expand Up @@ -410,6 +439,83 @@ ChannelCredentials createMtlsChannelCredentials() throws IOException, GeneralSec
return null;
}

@VisibleForTesting
boolean isGoogleS2AEnabled() {
String S2AEnv = envProvider.getenv(S2A_ENV_ENABLE_USE_S2A);
boolean isS2AEnv = Boolean.parseBoolean(S2AEnv);
if (isS2AEnv) {
return true;
}
return false;
}

@VisibleForTesting
boolean shouldUseS2A() {
// If EXPERIMENTAL_GOOGLE_API_USE_S2A is not set to true, skip S2A.
if (!isGoogleS2AEnabled()) {
return false;
}

// If {@link mtlsEndpoint} is not set, skip S2A. S2A is also skipped when there is endpoint
// override. Endpoint override is respected when the {@link endpoint} is resolved via AIP#4114,
// see EndpointContext.java
if (endpoint != mtlsEndpoint) {
return false;
}

// mTLS via S2A is not supported in any universe other than googleapis.com.
if (!endpoint.contains(Credentials.GOOGLE_DEFAULT_UNIVERSE)) {
return false;
}

return true;
}

@VisibleForTesting
ChannelCredentials createMtlsToS2AChannelCredentials() throws IOException {
if (!isOnComputeEngine()) {
// Currently, MTLS to MDS is only available on GCE. See:
// https://cloud.google.com/compute/docs/metadata/overview#https-mds
return null;
}
File privateKeyFile = new File(MTLS_MDS_CERT_CHAIN_AND_KEY);
File certChainFile = new File(MTLS_MDS_CERT_CHAIN_AND_KEY);
File trustBundleFile = new File(MTLS_MDS_ROOT);
if (!privateKeyFile.isFile() || !certChainFile.isFile() || !trustBundleFile.isFile()) {
return null;
}
return TlsChannelCredentials.newBuilder()
.keyManager(privateKeyFile, certChainFile)
.trustManager(trustBundleFile)
.build();
}

@VisibleForTesting
ChannelCredentials createS2ASecuredChannelCredentials() {
S2A s2aUtils = S2A.newBuilder().build();
String plaintextAddress = s2aUtils.getPlaintextS2AAddress();
String mtlsAddress = s2aUtils.getMtlsS2AAddress();
if (!mtlsAddress.isEmpty()) {
try {
// Try to connect to S2A using mTLS.
ChannelCredentials mtlsToS2AChannelCredentials = createMtlsToS2AChannelCredentials();
if (mtlsToS2AChannelCredentials != null) {
return S2AChannelCredentials.newBuilder(mtlsAddress, mtlsToS2AChannelCredentials).build();
}
} catch (IOException ignore) {
// Fallback to plaintext connection to S2A.
}
}

if (!plaintextAddress.isEmpty()) {
// Fallback to plaintext connection to S2A.
return S2AChannelCredentials.newBuilder(plaintextAddress, InsecureChannelCredentials.create())
.build();
}

return null;
}

private ManagedChannel createSingleChannel() throws IOException {
GrpcHeaderInterceptor headerInterceptor =
new GrpcHeaderInterceptor(headersWithDuplicatesRemoved);
Expand Down Expand Up @@ -447,16 +553,30 @@ private ManagedChannel createSingleChannel() throws IOException {
builder.keepAliveTime(DIRECT_PATH_KEEP_ALIVE_TIME_SECONDS, TimeUnit.SECONDS);
builder.keepAliveTimeout(DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS, TimeUnit.SECONDS);
} else {
// Try and create credentials via DCA. See https://google.aip.dev/auth/4114.
ChannelCredentials channelCredentials;
try {
channelCredentials = createMtlsChannelCredentials();
} catch (GeneralSecurityException e) {
throw new IOException(e);
}
if (channelCredentials != null) {
// Create the channel using channel credentials created via DCA.
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
} else {
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
// Could not create channel credentials via DCA. In accordance with
// https://google.aip.dev/auth/4115, if credentials not available through
// DCA, try mTLS with credentials held by the S2A (Secure Session Agent).
if (shouldUseS2A()) {
channelCredentials = createS2ASecuredChannelCredentials();
}
if (channelCredentials != null) {
// Create the channel using S2A-secured channel credentials.
builder = Grpc.newChannelBuilder(mtlsEndpoint, channelCredentials);
} else {
// Use default if we cannot initialize channel credentials via DCA or S2A.
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
}
}
}
// google-c2p resolver requires service config lookup
Expand Down Expand Up @@ -547,6 +667,11 @@ public String getEndpoint() {
return endpoint;
}

/** The mTLS endpoint. */
public String getMtlsEndpoint() {
return mtlsEndpoint;
}

/** This method is obsolete. Use {@link #getKeepAliveTimeDuration()} instead. */
@ObsoleteApi("Use getKeepAliveTimeDuration() instead")
public org.threeten.bp.Duration getKeepAliveTime() {
Expand Down Expand Up @@ -604,6 +729,7 @@ public static final class Builder {
private Executor executor;
private HeaderProvider headerProvider;
private String endpoint;
private String mtlsEndpoint;
private EnvironmentProvider envProvider;
private MtlsProvider mtlsProvider = new MtlsProvider();
@Nullable private GrpcInterceptorProvider interceptorProvider;
Expand Down Expand Up @@ -632,6 +758,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.executor = provider.executor;
this.headerProvider = provider.headerProvider;
this.endpoint = provider.endpoint;
this.mtlsEndpoint = provider.mtlsEndpoint;
this.envProvider = provider.envProvider;
this.interceptorProvider = provider.interceptorProvider;
this.maxInboundMessageSize = provider.maxInboundMessageSize;
Expand Down Expand Up @@ -700,6 +827,12 @@ public Builder setEndpoint(String endpoint) {
return this;
}

public Builder setMtlsEndpoint(String mtlsEndpoint) {
validateEndpoint(mtlsEndpoint);
this.mtlsEndpoint = mtlsEndpoint;
return this;
}

@VisibleForTesting
Builder setMtlsProvider(MtlsProvider mtlsProvider) {
this.mtlsProvider = mtlsProvider;
Expand All @@ -722,6 +855,10 @@ public String getEndpoint() {
return endpoint;
}

public String getMtlsEndpoint() {
return mtlsEndpoint;
}

/** The maximum message size allowed to be received on the channel. */
public Builder setMaxInboundMessageSize(Integer max) {
this.maxInboundMessageSize = max;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,21 @@ public boolean needsEndpoint() {
return false;
}

@Override
public boolean needsMtlsEndpoint() {
return false;
}

@Override
public TransportChannelProvider withEndpoint(String endpoint) {
throw new UnsupportedOperationException("LocalChannelProvider doesn't need an endpoint");
}

@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
throw new UnsupportedOperationException("LocalChannelProvider doesn't need an mtlsEndpoint");
}

@Override
@BetaApi("The surface for customizing pool size is not stable yet and may change in the future.")
public boolean acceptsPoolSize() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,22 @@ public boolean needsEndpoint() {
return endpoint == null;
}

@Override
public boolean needsMtlsEndpoint() {
return false;
}

@Override
public TransportChannelProvider withEndpoint(String endpoint) {
return toBuilder().setEndpoint(endpoint).build();
}

@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
throw new UnsupportedOperationException(
"InstantiatingHttpJsonChannelProvider doesn't need an mtlsEndpoint");
}

/** @deprecated REST transport channel doesn't support channel pooling */
@Deprecated
@Override
Expand Down
5 changes: 5 additions & 0 deletions gax-java/gax/clirr-ignored-differences.xml
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,9 @@
<className>com/google/api/gax/batching/Batcher</className>
<method>*</method>
</difference>
<difference>
<differenceType>7012</differenceType>
<className>com/google/api/gax/rpc/TransportChannelProvider</className>
<method>*</method>
</difference>
</differences>
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ public static ClientContext create(StubSettings settings) throws IOException {
// A valid EndpointContext should have been created in the StubSettings
EndpointContext endpointContext = settings.getEndpointContext();
String endpoint = endpointContext.resolvedEndpoint();
String mtlsEndpoint = settings.getMtlsEndpoint();
Credentials credentials = getCredentials(settings);
// check if need to adjust credentials/endpoint/endpointContext for GDC-H
String settingsGdchApiAudience = settings.getGdchApiAudience();
Expand Down Expand Up @@ -222,6 +223,9 @@ public static ClientContext create(StubSettings settings) throws IOException {
if (transportChannelProvider.needsEndpoint()) {
transportChannelProvider = transportChannelProvider.withEndpoint(endpoint);
}
if (transportChannelProvider.needsMtlsEndpoint()) {
transportChannelProvider = transportChannelProvider.withMtlsEndpoint(mtlsEndpoint);
}
TransportChannel transportChannel = transportChannelProvider.getTransportChannel();

ApiCallContext defaultCallContext =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,23 @@ public boolean needsEndpoint() {
return false;
}

@Override
public boolean needsMtlsEndpoint() {
return false;
}

@Override
public TransportChannelProvider withEndpoint(String endpoint) {
throw new UnsupportedOperationException(
"FixedTransportChannelProvider doesn't need an endpoint");
}

@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
throw new UnsupportedOperationException(
"FixedTransportChannelProvider doesn't need an mtlsEndpoint");
}

/** @deprecated FixedTransportChannelProvider doesn't support ChannelPool configuration */
@Deprecated
@Override
Expand Down
Loading

0 comments on commit 3510643

Please sign in to comment.