diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 289acb1a745..297f190f9d9 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -684,22 +684,11 @@ QueryOptions buildQueryOptions(QueryOptions requestOptions) { } RequestOptions buildRequestOptions(Options options) { - // Shortcut for the most common return value. - if (!(options.hasPriority() || options.hasTag() || getTransactionTag() != null)) { - return RequestOptions.getDefaultInstance(); - } - - RequestOptions.Builder builder = RequestOptions.newBuilder(); - if (options.hasPriority()) { - builder.setPriority(options.priority()); - } - if (options.hasTag()) { - builder.setRequestTag(options.tag()); - } + RequestOptions requestOptions = options.toRequestOptionsProto(false); if (getTransactionTag() != null) { - builder.setTransactionTag(getTransactionTag()); + return requestOptions.toBuilder().setTransactionTag(getTransactionTag()).build(); } - return builder.build(); + return requestOptions; } ExecuteSqlRequest.Builder getExecuteSqlRequestBuilder( diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java index 1e6ce34d672..88810c6b73b 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java @@ -20,6 +20,7 @@ import com.google.spanner.v1.DirectedReadOptions; import com.google.spanner.v1.ReadRequest.LockHint; import com.google.spanner.v1.ReadRequest.OrderBy; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.RequestOptions.Priority; import com.google.spanner.v1.TransactionOptions.IsolationLevel; import com.google.spanner.v1.TransactionOptions.ReadWrite.ReadLockMode; @@ -265,6 +266,37 @@ public static ReadQueryUpdateTransactionOption priority(RpcPriority priority) { return new PriorityOption(priority); } + /** + * Specifying this will add the given client context to the request. The client context is used to + * pass side-channel or configuration information to the backend, such as a user ID for a parameterized + * secure view. + */ + public static ReadQueryUpdateTransactionOption clientContext( + RequestOptions.ClientContext clientContext) { + return new ClientContextOption(clientContext); + } + + RequestOptions toRequestOptionsProto(boolean isTransactionOption) { + if (!hasPriority() && !hasTag() && !hasClientContext()) { + return RequestOptions.getDefaultInstance(); + } + RequestOptions.Builder builder = RequestOptions.newBuilder(); + if (hasPriority()) { + builder.setPriority(priority()); + } + if (hasTag()) { + if (isTransactionOption) { + builder.setTransactionTag(tag()); + } else { + builder.setRequestTag(tag()); + } + } + if (hasClientContext()) { + builder.setClientContext(clientContext()); + } + return builder.build(); + } + public static TransactionOption maxCommitDelay(Duration maxCommitDelay) { Preconditions.checkArgument(!maxCommitDelay.isNegative(), "maxCommitDelay should be positive"); return new MaxCommitDelayOption(maxCommitDelay); @@ -462,6 +494,20 @@ void appendToOptions(Options options) { } } + static final class ClientContextOption extends InternalOption + implements ReadQueryUpdateTransactionOption { + private final RequestOptions.ClientContext clientContext; + + ClientContextOption(RequestOptions.ClientContext clientContext) { + this.clientContext = clientContext; + } + + @Override + void appendToOptions(Options options) { + options.clientContext = clientContext; + } + } + static final class TagOption extends InternalOption implements ReadQueryUpdateTransactionOption { private final String tag; @@ -574,6 +620,7 @@ void appendToOptions(Options options) { private String filter; private RpcPriority priority; private String tag; + private RequestOptions.ClientContext clientContext; private String etag; private Boolean validateOnly; private Boolean withExcludeTxnFromChangeStreams; @@ -666,6 +713,14 @@ Priority priority() { return priority == null ? null : priority.proto; } + boolean hasClientContext() { + return clientContext != null; + } + + RequestOptions.ClientContext clientContext() { + return clientContext; + } + boolean hasTag() { return tag != null; } @@ -777,6 +832,9 @@ public String toString() { if (priority != null) { b.append("priority: ").append(priority).append(' '); } + if (clientContext != null) { + b.append("clientContext: ").append(clientContext).append(' '); + } if (tag != null) { b.append("tag: ").append(tag).append(' '); } @@ -850,6 +908,7 @@ public boolean equals(Object o) { && Objects.equals(pageToken(), that.pageToken()) && Objects.equals(filter(), that.filter()) && Objects.equals(priority(), that.priority()) + && Objects.equals(clientContext(), that.clientContext()) && Objects.equals(tag(), that.tag()) && Objects.equals(etag(), that.etag()) && Objects.equals(validateOnly(), that.validateOnly()) @@ -894,6 +953,9 @@ public int hashCode() { if (priority != null) { result = 31 * result + priority.hashCode(); } + if (clientContext != null) { + result = 31 * result + clientContext.hashCode(); + } if (tag != null) { result = 31 * result + tag.hashCode(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index d86ff807afd..ee4269b107d 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -481,9 +481,12 @@ ApiFuture beginTransactionAsync( if (sessionReference.getIsMultiplexed() && mutation != null) { requestBuilder.setMutationKey(mutation); } - if (sessionReference.getIsMultiplexed() && !Strings.isNullOrEmpty(transactionOptions.tag())) { - requestBuilder.setRequestOptions( - RequestOptions.newBuilder().setTransactionTag(transactionOptions.tag()).build()); + RequestOptions requestOptions = transactionOptions.toRequestOptionsProto(true); + if (!sessionReference.getIsMultiplexed()) { + requestOptions = requestOptions.toBuilder().clearTransactionTag().build(); + } + if (!requestOptions.equals(RequestOptions.getDefaultInstance())) { + requestBuilder.setRequestOptions(requestOptions); } final BeginTransactionRequest request = requestBuilder.build(); final ApiFuture requestFuture; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java index 7afccce194c..d28566cef89 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java @@ -464,15 +464,9 @@ public void run() { waitForTransactionTimeoutMillis, TimeUnit.MILLISECONDS) : transactionId); } - if (options.hasPriority() || getTransactionTag() != null) { - RequestOptions.Builder requestOptionsBuilder = RequestOptions.newBuilder(); - if (options.hasPriority()) { - requestOptionsBuilder.setPriority(options.priority()); - } - if (getTransactionTag() != null) { - requestOptionsBuilder.setTransactionTag(getTransactionTag()); - } - requestBuilder.setRequestOptions(requestOptionsBuilder.build()); + RequestOptions requestOptions = options.toRequestOptionsProto(true); + if (!requestOptions.equals(RequestOptions.getDefaultInstance())) { + requestBuilder.setRequestOptions(requestOptions); } if (session.getIsMultiplexed() && getLatestPrecommitToken() != null) { // Set the precommit token in the CommitRequest for multiplexed sessions. diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java index 75a207043c2..1d71e062cbb 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java @@ -80,6 +80,7 @@ abstract class AbstractBaseUnitOfWork implements UnitOfWork { protected final List transactionRetryListeners; protected final boolean excludeTxnFromChangeStreams; protected final RpcPriority rpcPriority; + protected final com.google.spanner.v1.RequestOptions.ClientContext clientContext; protected final Span span; /** Class for keeping track of the stacktrace of the caller of an async statement. */ @@ -117,6 +118,7 @@ abstract static class Builder, T extends AbstractBaseUni private boolean excludeTxnFromChangeStreams; private RpcPriority rpcPriority; + private com.google.spanner.v1.RequestOptions.ClientContext clientContext; private Span span; Builder() {} @@ -163,6 +165,11 @@ B setRpcPriority(@Nullable RpcPriority rpcPriority) { return self(); } + B setClientContext(@Nullable com.google.spanner.v1.RequestOptions.ClientContext clientContext) { + this.clientContext = clientContext; + return self(); + } + B setSpan(@Nullable Span span) { this.span = span; return self(); @@ -179,6 +186,7 @@ B setSpan(@Nullable Span span) { this.transactionRetryListeners = builder.transactionRetryListeners; this.excludeTxnFromChangeStreams = builder.excludeTxnFromChangeStreams; this.rpcPriority = builder.rpcPriority; + this.clientContext = builder.clientContext; this.span = Preconditions.checkNotNull(builder.span); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java index 533be8a047f..60d739a3c85 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java @@ -449,6 +449,25 @@ default String getStatementTag() { throw new UnsupportedOperationException(); } + /** + * Sets the client context to use for the statements that are executed. The client context + * persists until it is changed or cleared. + * + * @param clientContext The client context to use with the statements that will be executed on + * this connection. + */ + default void setClientContext(com.google.spanner.v1.RequestOptions.ClientContext clientContext) { + throw new UnsupportedOperationException(); + } + + /** + * @return The client context that will be used with the statements that are executed on this + * connection. + */ + default com.google.spanner.v1.RequestOptions.ClientContext getClientContext() { + throw new UnsupportedOperationException(); + } + /** * Sets whether the next transaction should be excluded from all change streams with the DDL * option `allow_txn_exclusion=true` diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java index cfd63c89d49..cadd6375739 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java @@ -94,6 +94,7 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.spanner.v1.DirectedReadOptions; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.TransactionOptions.IsolationLevel; import com.google.spanner.v1.TransactionOptions.ReadWrite.ReadLockMode; @@ -299,6 +300,7 @@ static UnitOfWorkType of(TransactionMode transactionMode) { private IsolationLevel transactionIsolationLevel; private String transactionTag; private String statementTag; + private RequestOptions.ClientContext clientContext; private boolean excludeTxnFromChangeStreams; private byte[] protoDescriptors; private String protoDescriptorsFilePath; @@ -536,6 +538,7 @@ private void reset(Context context, boolean inTransaction) { this.connectionState.resetValue(SAVEPOINT_SUPPORT, context, inTransaction); this.protoDescriptors = null; this.protoDescriptorsFilePath = null; + this.clientContext = null; if (!isTransactionStarted()) { setDefaultTransactionOptions(getDefaultIsolationLevel()); @@ -955,6 +958,18 @@ public String getTransactionTag() { return transactionTag; } + @Override + public void setClientContext(RequestOptions.ClientContext clientContext) { + ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG); + this.clientContext = clientContext; + } + + @Override + public RequestOptions.ClientContext getClientContext() { + ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG); + return clientContext; + } + @Override public void setTransactionTag(String tag) { ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG); @@ -2026,6 +2041,9 @@ private QueryOption[] mergeQueryRequestOptions( options = appendQueryOption(options, Options.priority(getConnectionPropertyValue(RPC_PRIORITY))); } + if (clientContext != null) { + options = appendQueryOption(options, Options.clientContext(clientContext)); + } if (currentUnitOfWork != null && currentUnitOfWork.supportsDirectedReads(parsedStatement) && getConnectionPropertyValue(DIRECTED_READ) != null) { @@ -2070,6 +2088,14 @@ private UpdateOption[] mergeUpdateRequestOptions(UpdateOption... options) { options[options.length - 1] = Options.priority(getConnectionPropertyValue(RPC_PRIORITY)); } } + if (clientContext != null) { + if (options == null || options.length == 0) { + options = new UpdateOption[] {Options.clientContext(clientContext)}; + } else { + options = Arrays.copyOf(options, options.length + 1); + options[options.length - 1] = Options.clientContext(clientContext); + } + } return options; } @@ -2299,6 +2325,7 @@ UnitOfWork createNewUnitOfWork( createSpanForUnitOfWork( statementType == StatementType.DDL ? DDL_STATEMENT : SINGLE_USE_TRANSACTION)) .setProtoDescriptors(getProtoDescriptors()) + .setClientContext(clientContext) .build(); if (!isInternalMetadataQuery && !forceSingleUse) { // Reset the transaction options after starting a single-use transaction. @@ -2317,6 +2344,7 @@ UnitOfWork createNewUnitOfWork( .setTransactionTag(transactionTag) .setRpcPriority(getConnectionPropertyValue(RPC_PRIORITY)) .setSpan(createSpanForUnitOfWork(READ_ONLY_TRANSACTION)) + .setClientContext(clientContext) .build(); case READ_WRITE_TRANSACTION: return ReadWriteTransaction.newBuilder() @@ -2340,6 +2368,7 @@ UnitOfWork createNewUnitOfWork( .setExcludeTxnFromChangeStreams(excludeTxnFromChangeStreams) .setRpcPriority(getConnectionPropertyValue(RPC_PRIORITY)) .setSpan(createSpanForUnitOfWork(READ_WRITE_TRANSACTION)) + .setClientContext(clientContext) .build(); case DML_BATCH: // A DML batch can run inside the current transaction. It should therefore only @@ -2359,6 +2388,7 @@ UnitOfWork createNewUnitOfWork( .setRpcPriority(getConnectionPropertyValue(RPC_PRIORITY)) // Use the transaction Span for the DML batch. .setSpan(transactionStack.peek().getSpan()) + .setClientContext(clientContext) .build(); case DDL_BATCH: return DdlBatch.newBuilder() @@ -2369,6 +2399,7 @@ UnitOfWork createNewUnitOfWork( .setSpan(createSpanForUnitOfWork(DDL_BATCH)) .setProtoDescriptors(getProtoDescriptors()) .setConnectionState(connectionState) + .setClientContext(clientContext) .build(); default: } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java index c0e464ee5e6..ccb592e3f84 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java @@ -350,6 +350,9 @@ private TransactionOption[] extractOptions(Builder builder) { if (this.readLockMode != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { numOptions++; } + if (this.clientContext != null) { + numOptions++; + } TransactionOption[] options = new TransactionOption[numOptions]; int index = 0; if (builder.returnCommitStats) { @@ -373,6 +376,9 @@ private TransactionOption[] extractOptions(Builder builder) { if (this.readLockMode != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { options[index++] = Options.readLockMode(this.readLockMode); } + if (this.clientContext != null) { + options[index++] = Options.clientContext(this.clientContext); + } return options; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java index 370b579e6e2..cfb13cef966 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java @@ -520,6 +520,9 @@ private TransactionRunner createWriteTransaction() { != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { numOptions++; } + if (this.clientContext != null) { + numOptions++; + } if (numOptions == 0) { return dbClient.readWriteTransaction(); } @@ -547,6 +550,9 @@ private TransactionRunner createWriteTransaction() { != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { options[index++] = Options.readLockMode(connectionState.getValue(READ_LOCK_MODE).getValue()); } + if (this.clientContext != null) { + options[index++] = Options.clientContext(this.clientContext); + } return dbClient.readWriteTransaction(options); } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java index eea6658d26d..d1f1f0b8c16 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java @@ -345,6 +345,18 @@ public void executeSqlRequestBuilderWithRequestOptionsWithTxnTag() { assertThat(request.getRequestOptions().getTransactionTag()).isEqualTo("app=spanner,env=test"); } + @Test + public void testBuildRequestOptionsWithClientContext() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + RequestOptions requestOptions = + context.buildRequestOptions(Options.fromQueryOptions(Options.clientContext(clientContext))); + assertEquals(clientContext, requestOptions.getClientContext()); + } + @Test public void testGetExecuteSqlRequestBuilderWithDirectedReadOptions() { ExecuteSqlRequest.Builder request = diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java index 8571c42b3dd..52cd2db7798 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java @@ -34,6 +34,7 @@ import com.google.spanner.v1.DirectedReadOptions.ReplicaSelection; import com.google.spanner.v1.ReadRequest.LockHint; import com.google.spanner.v1.ReadRequest.OrderBy; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.RequestOptions.Priority; import com.google.spanner.v1.TransactionOptions.IsolationLevel; import com.google.spanner.v1.TransactionOptions.ReadWrite; @@ -45,6 +46,32 @@ /** Unit tests for {@link Options}. */ @RunWith(JUnit4.class) public class OptionsTest { + @Test + public void testToRequestOptionsProto() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + Options options = + Options.fromQueryOptions( + Options.priority(RpcPriority.HIGH), + Options.tag("tag"), + Options.clientContext(clientContext)); + + RequestOptions protoForStatement = options.toRequestOptionsProto(false); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, protoForStatement.getPriority()); + assertEquals("tag", protoForStatement.getRequestTag()); + assertEquals("", protoForStatement.getTransactionTag()); + assertEquals(clientContext, protoForStatement.getClientContext()); + + RequestOptions protoForTransaction = options.toRequestOptionsProto(true); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, protoForTransaction.getPriority()); + assertEquals("", protoForTransaction.getRequestTag()); + assertEquals("tag", protoForTransaction.getTransactionTag()); + assertEquals(clientContext, protoForTransaction.getClientContext()); + } + private static final DirectedReadOptions DIRECTED_READ_OPTIONS = DirectedReadOptions.newBuilder() .setIncludeReplicas( diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java index 1ac3b7beaf7..35117568926 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java @@ -47,7 +47,10 @@ import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.Mutation.Write; import com.google.spanner.v1.PartialResultSet; +import com.google.spanner.v1.RequestOptions; +import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.RollbackRequest; import com.google.spanner.v1.Session; import com.google.spanner.v1.Transaction; @@ -77,6 +80,42 @@ /** Unit tests for {@link com.google.cloud.spanner.SessionImpl}. */ @RunWith(JUnit4.class) public class SessionImplTest { + @Test + public void testBeginTransactionWithClientContext() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + Mockito.when( + rpc.beginTransactionAsync( + Mockito.any(BeginTransactionRequest.class), anyMap(), eq(true))) + .thenReturn( + ApiFutures.immediateFuture( + Transaction.newBuilder().setId(ByteString.copyFromUtf8("tx")).build())); + + ((SessionImpl) session) + .beginTransactionAsync( + Options.fromTransactionOptions( + Options.priority(Options.RpcPriority.HIGH), + Options.tag("tag"), + Options.clientContext(clientContext)), + true, + Collections.emptyMap(), + null, + null); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(BeginTransactionRequest.class); + Mockito.verify(rpc).beginTransactionAsync(requestCaptor.capture(), anyMap(), eq(true)); + BeginTransactionRequest request = requestCaptor.getValue(); + RequestOptions requestOptions = request.getRequestOptions(); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, requestOptions.getPriority()); + // TransactionTag should NOT be set because session is not multiplexed. + assertEquals("", requestOptions.getTransactionTag()); + assertEquals(clientContext, requestOptions.getClientContext()); + } + @Mock private SpannerRpc rpc; @Mock private SpannerOptions spannerOptions; private com.google.cloud.spanner.Session session; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java index 3e3358a53bc..15ba8ec5acf 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.anyMap; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; @@ -53,6 +54,7 @@ import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.ResultSetStats; @@ -79,6 +81,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; @@ -99,6 +102,37 @@ public void release(ScheduledExecutorService exec) { } } + @Test + public void testCommitWithClientContext() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + when(session.getName()).thenReturn("projects/p/instances/i/databases/d/sessions/s"); + when(session.newTransaction(any(Options.class), any())).thenReturn(txn); + Mockito.clearInvocations(session); + transactionRunner = + new TransactionRunnerImpl( + session, + Options.priority(Options.RpcPriority.HIGH), + Options.tag("tag"), + Options.clientContext(clientContext)); + transactionRunner.setSpan(span); + + transactionRunner.run( + transaction -> { + return null; + }); + + ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(Options.class); + verify(session).newTransaction(optionsCaptor.capture(), any()); + Options capturedOptions = optionsCaptor.getValue(); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, capturedOptions.priority()); + assertEquals("tag", capturedOptions.tag()); + assertEquals(clientContext, capturedOptions.clientContext()); + } + @Mock private SpannerRpc rpc; @Mock private SessionImpl session; @Mock private TransactionRunnerImpl.TransactionContextImpl txn; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ClientContextMockServerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ClientContextMockServerTest.java new file mode 100644 index 00000000000..63287e1afac --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ClientContextMockServerTest.java @@ -0,0 +1,244 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.connection; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; + +import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.ResultSet; +import com.google.protobuf.Value; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.RequestOptions; +import java.util.Collections; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class ClientContextMockServerTest extends AbstractMockServerTest { + + @Parameters(name = "dialect = {0}") + public static Object[] data() { + return Dialect.values(); + } + + @Parameter public Dialect dialect; + + private Dialect currentDialect; + + private static final RequestOptions.ClientContext CLIENT_CONTEXT = + RequestOptions.ClientContext.newBuilder() + .putSecureContext("test-key", Value.newBuilder().setStringValue("test-value").build()) + .build(); + + @Before + public void setupDialect() { + if (currentDialect != dialect) { + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.detectDialectResult(dialect)); + SpannerPool.closeSpannerPool(); + currentDialect = dialect; + } + } + + @After + public void clearRequests() { + mockSpanner.clearRequests(); + } + + @Test + public void testQuery_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testUpdate_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.executeUpdate(INSERT_STATEMENT); + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testBatchUpdate_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.executeBatchUpdate(Collections.singletonList(INSERT_STATEMENT)); + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteBatchDmlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testCommit_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.executeUpdate(INSERT_STATEMENT); + connection.commit(); + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(CommitRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testBeginTransaction_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.beginTransaction(); + // BeginTransaction is executed lazily, so we need to execute a statement to trigger it. + // However, the Connection API might do it eagerly if we call beginTransaction(). + // Let's check the implementation of beginTransactionAsync. + // It calls transactionManager.begin() which eventually calls BeginTransaction RPC. + // But ReadWriteTransaction waits until the first statement unless DELAY_TRANSACTION_START is false. + // Actually, ConnectionImpl.beginTransactionAsync calls clearLastTransactionAndSetDefaultTransactionOptions. + // It does NOT start the transaction on Spanner immediately. + // The transaction is started when the first statement is executed. + + // So let's execute a statement. + connection.executeUpdate(INSERT_STATEMENT); + + // Now checking requests. + // If the transaction started lazily (default), the first statement might be BeginTransaction? + // Or if it's ReadWrite, it might use inline BeginTransaction. + // If it uses inline BeginTransaction, then ExecuteSqlRequest will have transaction.begin. + + // Let's force an explicit BeginTransaction by executing a statement that requires it, + // or by checking if we can force it. + // But typically, for RW transaction, the first statement has BeginTransaction option. + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertEquals(CLIENT_CONTEXT, request.getRequestOptions().getClientContext()); + + // If we want to test BeginTransaction RPC specifically, we need to ensure it is called. + // It is called if we do `connection.beginTransaction()` and then something that triggers it? + // Or if we use `DELAY_TRANSACTION_START_UNTIL_FIRST_WRITE=false`? + // The default is `true` (I think). + + // If we use explicit transaction management via Connection, `beginTransaction` just sets state. + } + } + + @Test + public void testPersistence() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + + connection.executeUpdate(INSERT_STATEMENT); + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(1) + .getRequestOptions() + .getClientContext()); + + connection.commit(); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(CommitRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testClearClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + + connection.setClientContext(null); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertFalse( + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(1) + .getRequestOptions() + .hasClientContext()); + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java index ead2fd0f655..d430c04bbd1 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java @@ -1948,6 +1948,35 @@ private void assertThrowResultNotAllowed( "Only statements that return a result of one of the following types are allowed")); } + @Test + public void testSetAndGetClientContext() { + try (Connection connection = createConnection(ConnectionOptions.newBuilder().setUri(URI).build())) { + com.google.spanner.v1.RequestOptions.ClientContext context = + com.google.spanner.v1.RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("test").build()) + .build(); + connection.setClientContext(context); + assertEquals(context, connection.getClientContext()); + } + } + + @Test + public void testResetClearsClientContext() { + try (Connection connection = createConnection(ConnectionOptions.newBuilder().setUri(URI).build())) { + com.google.spanner.v1.RequestOptions.ClientContext context = + com.google.spanner.v1.RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("test").build()) + .build(); + connection.setClientContext(context); + assertEquals(context, connection.getClientContext()); + + connection.reset(); + assertNull(connection.getClientContext()); + } + } + @Test public void testProtoDescriptorsAlwaysAllowed() { ConnectionOptions connectionOptions = mock(ConnectionOptions.class);