diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 4bd63441e5..b75ba746f9 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -141,6 +141,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP @Nullable private final Boolean keepAliveWithoutCalls; private final ChannelPoolSettings channelPoolSettings; @Nullable private final Credentials credentials; + @Nullable private final CallCredentials altsCallCredentials; @Nullable private final CallCredentials mtlsS2ACallCredentials; @Nullable private final ChannelPrimer channelPrimer; @Nullable private final Boolean attemptDirectPath; @@ -191,6 +192,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) { this.channelPoolSettings = builder.channelPoolSettings; this.channelConfigurator = builder.channelConfigurator; this.credentials = builder.credentials; + this.altsCallCredentials = builder.altsCallCredentials; this.mtlsS2ACallCredentials = builder.mtlsS2ACallCredentials; this.channelPrimer = builder.channelPrimer; this.attemptDirectPath = builder.attemptDirectPath; @@ -616,8 +618,14 @@ private ManagedChannel createSingleChannel() throws IOException { boolean useDirectPathXds = false; if (canUseDirectPath()) { CallCredentials callCreds = MoreCallCredentials.from(credentials); + // altsCallCredentials may be null and GoogleDefaultChannelCredentials + // will solely use callCreds. Otherwise it uses altsCallCredentials + // for DirectPath connections and callCreds for CloudPath fallbacks. ChannelCredentials channelCreds = - GoogleDefaultChannelCredentials.newBuilder().callCredentials(callCreds).build(); + GoogleDefaultChannelCredentials.newBuilder() + .callCredentials(callCreds) + .altsCallCredentials(altsCallCredentials) + .build(); useDirectPathXds = isDirectPathXdsEnabled(); if (useDirectPathXds) { // google-c2p: CloudToProd(C2P) Directpath. This scheme is defined in @@ -822,6 +830,7 @@ public static final class Builder { @Nullable private Boolean keepAliveWithoutCalls; @Nullable private ApiFunction channelConfigurator; @Nullable private Credentials credentials; + @Nullable private CallCredentials altsCallCredentials; @Nullable private CallCredentials mtlsS2ACallCredentials; @Nullable private ChannelPrimer channelPrimer; private ChannelPoolSettings channelPoolSettings; @@ -853,6 +862,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) { this.keepAliveWithoutCalls = provider.keepAliveWithoutCalls; this.channelConfigurator = provider.channelConfigurator; this.credentials = provider.credentials; + this.altsCallCredentials = provider.altsCallCredentials; this.mtlsS2ACallCredentials = provider.mtlsS2ACallCredentials; this.channelPrimer = provider.channelPrimer; this.channelPoolSettings = provider.channelPoolSettings; @@ -919,6 +929,7 @@ Builder setUseS2A(boolean useS2A) { this.useS2A = useS2A; return this; } + /* * Sets the allowed hard bound token types for this TransportChannelProvider. * @@ -996,6 +1007,7 @@ public Integer getMaxInboundMetadataSize() { public Builder setKeepAliveTime(org.threeten.bp.Duration duration) { return setKeepAliveTimeDuration(toJavaTimeDuration(duration)); } + /** The time without read activity before sending a keepalive ping. */ public Builder setKeepAliveTimeDuration(java.time.Duration duration) { this.keepAliveTime = duration; @@ -1172,6 +1184,18 @@ boolean isMtlsS2AHardBoundTokensEnabled() { .anyMatch(val -> val.equals(HardBoundTokenTypes.MTLS_S2A)); } + boolean isDirectPathBoundTokenEnabled() { + // If the list of allowed hard bound token types is empty or doesn't contain + // {@code HardBoundTokenTypes.ALTS}, the {@code credentials} are null or not of type + // {@code ComputeEngineCredentials} then DirectPath hard bound tokens should not be used. + // DirectPath hard bound tokens should only be used on ALTS channels. + if (allowedHardBoundTokenTypes.isEmpty() + || this.credentials == null + || !(credentials instanceof ComputeEngineCredentials)) return false; + return allowedHardBoundTokenTypes.stream() + .anyMatch(val -> val.equals(HardBoundTokenTypes.ALTS)); + } + CallCredentials createHardBoundTokensCallCredentials( ComputeEngineCredentials.GoogleAuthTransport googleAuthTransport, ComputeEngineCredentials.BindingEnforcement bindingEnforcement) { @@ -1194,6 +1218,11 @@ public InstantiatingGrpcChannelProvider build() { ComputeEngineCredentials.GoogleAuthTransport.MTLS, ComputeEngineCredentials.BindingEnforcement.ON); } + if (isDirectPathBoundTokenEnabled()) { + this.altsCallCredentials = + createHardBoundTokensCallCredentials( + ComputeEngineCredentials.GoogleAuthTransport.ALTS, null); + } InstantiatingGrpcChannelProvider instantiatingGrpcChannelProvider = new InstantiatingGrpcChannelProvider(this); instantiatingGrpcChannelProvider.removeApiKeyCredentialDuplicateHeaders(); diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java index 9540235b18..86203ce47d 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java @@ -39,6 +39,7 @@ import com.google.api.core.ApiFunction; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder; +import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.HardBoundTokenTypes; import com.google.api.gax.rpc.FixedHeaderProvider; import com.google.api.gax.rpc.HeaderProvider; import com.google.api.gax.rpc.TransportChannel; @@ -735,6 +736,59 @@ public void canUseDirectPath_happyPath() throws IOException { .setEndpoint(DEFAULT_ENDPOINT) .setEnvProvider(envProvider) .setHeaderProvider(Mockito.mock(HeaderProvider.class)); + Truth.assertThat(builder.isDirectPathBoundTokenEnabled()).isFalse(); + InstantiatingGrpcChannelProvider provider = + new InstantiatingGrpcChannelProvider(builder, GCE_PRODUCTION_NAME_AFTER_2016); + Truth.assertThat(provider.canUseDirectPath()).isTrue(); + + // verify this info is passed correctly to transport channel + TransportChannel transportChannel = provider.getTransportChannel(); + Truth.assertThat(((GrpcTransportChannel) transportChannel).isDirectPath()).isTrue(); + transportChannel.shutdownNow(); + } + + @Test + public void canUseDirectPath_boundTokenNotEnabledWithNonComputeCredentials() { + System.setProperty("os.name", "Linux"); + Credentials credentials = Mockito.mock(Credentials.class); + EnvironmentProvider envProvider = Mockito.mock(EnvironmentProvider.class); + Mockito.when( + envProvider.getenv( + InstantiatingGrpcChannelProvider.DIRECT_PATH_ENV_DISABLE_DIRECT_PATH)) + .thenReturn("false"); + InstantiatingGrpcChannelProvider.Builder builder = + InstantiatingGrpcChannelProvider.newBuilder() + .setAttemptDirectPath(true) + .setAllowHardBoundTokenTypes(Collections.singletonList(HardBoundTokenTypes.ALTS)) + .setCredentials(credentials) + .setEndpoint(DEFAULT_ENDPOINT) + .setEnvProvider(envProvider); + Truth.assertThat(builder.isDirectPathBoundTokenEnabled()).isFalse(); + InstantiatingGrpcChannelProvider provider = + new InstantiatingGrpcChannelProvider(builder, GCE_PRODUCTION_NAME_AFTER_2016); + Truth.assertThat(provider.canUseDirectPath()).isFalse(); + } + + @Test + public void canUseDirectPath_happyPathWithBoundToken() throws IOException { + System.setProperty("os.name", "Linux"); + EnvironmentProvider envProvider = Mockito.mock(EnvironmentProvider.class); + Mockito.when( + envProvider.getenv( + InstantiatingGrpcChannelProvider.DIRECT_PATH_ENV_DISABLE_DIRECT_PATH)) + .thenReturn("false"); + // verify the credentials gets called and returns a non-null builder. + Mockito.when(computeEngineCredentials.toBuilder()) + .thenReturn(ComputeEngineCredentials.newBuilder()); + InstantiatingGrpcChannelProvider.Builder builder = + InstantiatingGrpcChannelProvider.newBuilder() + .setAttemptDirectPath(true) + .setCredentials(computeEngineCredentials) + .setAllowHardBoundTokenTypes(Collections.singletonList(HardBoundTokenTypes.ALTS)) + .setEndpoint(DEFAULT_ENDPOINT) + .setEnvProvider(envProvider) + .setHeaderProvider(Mockito.mock(HeaderProvider.class)); + Truth.assertThat(builder.isDirectPathBoundTokenEnabled()).isTrue(); InstantiatingGrpcChannelProvider provider = new InstantiatingGrpcChannelProvider(builder, GCE_PRODUCTION_NAME_AFTER_2016); Truth.assertThat(provider.canUseDirectPath()).isTrue();