From 525fad1ca26b5c1910fd89bc3b56371691738058 Mon Sep 17 00:00:00 2001 From: Adam Seering Date: Mon, 12 Jan 2026 10:44:07 +0000 Subject: [PATCH 1/2] feat: Add ClientContext to Options and propagate to RPCs This change adds support for ClientContext in Options and ensures it is propagated to ExecuteSql, Read, Commit, and BeginTransaction requests. It aligns with go/spanner-client-scoped-session-state design. - Added RequestOptions.ClientContext to Options. - Refactored request option building to Options.toRequestOptionsProto. - Updated AbstractReadContext, TransactionRunnerImpl, and SessionImpl to use the shared logic. - Added tests. --- .../cloud/spanner/AbstractReadContext.java | 17 +---- .../com/google/cloud/spanner/Options.java | 62 +++++++++++++++++++ .../com/google/cloud/spanner/SessionImpl.java | 9 ++- .../cloud/spanner/TransactionRunnerImpl.java | 12 +--- .../spanner/AbstractReadContextTest.java | 12 ++++ .../com/google/cloud/spanner/OptionsTest.java | 34 +++++++--- .../google/cloud/spanner/SessionImplTest.java | 39 ++++++++++++ .../spanner/TransactionRunnerImplTest.java | 34 ++++++++++ 8 files changed, 186 insertions(+), 33 deletions(-) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 289acb1a745..297f190f9d9 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -684,22 +684,11 @@ QueryOptions buildQueryOptions(QueryOptions requestOptions) { } RequestOptions buildRequestOptions(Options options) { - // Shortcut for the most common return value. - if (!(options.hasPriority() || options.hasTag() || getTransactionTag() != null)) { - return RequestOptions.getDefaultInstance(); - } - - RequestOptions.Builder builder = RequestOptions.newBuilder(); - if (options.hasPriority()) { - builder.setPriority(options.priority()); - } - if (options.hasTag()) { - builder.setRequestTag(options.tag()); - } + RequestOptions requestOptions = options.toRequestOptionsProto(false); if (getTransactionTag() != null) { - builder.setTransactionTag(getTransactionTag()); + return requestOptions.toBuilder().setTransactionTag(getTransactionTag()).build(); } - return builder.build(); + return requestOptions; } ExecuteSqlRequest.Builder getExecuteSqlRequestBuilder( diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java index 1e6ce34d672..1a5de25172b 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java @@ -20,6 +20,7 @@ import com.google.spanner.v1.DirectedReadOptions; import com.google.spanner.v1.ReadRequest.LockHint; import com.google.spanner.v1.ReadRequest.OrderBy; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.RequestOptions.Priority; import com.google.spanner.v1.TransactionOptions.IsolationLevel; import com.google.spanner.v1.TransactionOptions.ReadWrite.ReadLockMode; @@ -265,6 +266,37 @@ public static ReadQueryUpdateTransactionOption priority(RpcPriority priority) { return new PriorityOption(priority); } + /** + * Specifying this will add the given client context to the request. The client context is used to + * pass opaque side-channel information to the backend, such as a user ID for a parameterized + * secure view. + */ + public static ReadQueryUpdateTransactionOption clientContext( + RequestOptions.ClientContext clientContext) { + return new ClientContextOption(clientContext); + } + + RequestOptions toRequestOptionsProto(boolean isTransactionOption) { + if (!hasPriority() && !hasTag() && !hasClientContext()) { + return RequestOptions.getDefaultInstance(); + } + RequestOptions.Builder builder = RequestOptions.newBuilder(); + if (hasPriority()) { + builder.setPriority(priority()); + } + if (hasTag()) { + if (isTransactionOption) { + builder.setTransactionTag(tag()); + } else { + builder.setRequestTag(tag()); + } + } + if (hasClientContext()) { + builder.setClientContext(clientContext()); + } + return builder.build(); + } + public static TransactionOption maxCommitDelay(Duration maxCommitDelay) { Preconditions.checkArgument(!maxCommitDelay.isNegative(), "maxCommitDelay should be positive"); return new MaxCommitDelayOption(maxCommitDelay); @@ -462,6 +494,20 @@ void appendToOptions(Options options) { } } + static final class ClientContextOption extends InternalOption + implements ReadQueryUpdateTransactionOption { + private final RequestOptions.ClientContext clientContext; + + ClientContextOption(RequestOptions.ClientContext clientContext) { + this.clientContext = clientContext; + } + + @Override + void appendToOptions(Options options) { + options.clientContext = clientContext; + } + } + static final class TagOption extends InternalOption implements ReadQueryUpdateTransactionOption { private final String tag; @@ -574,6 +620,7 @@ void appendToOptions(Options options) { private String filter; private RpcPriority priority; private String tag; + private RequestOptions.ClientContext clientContext; private String etag; private Boolean validateOnly; private Boolean withExcludeTxnFromChangeStreams; @@ -666,6 +713,14 @@ Priority priority() { return priority == null ? null : priority.proto; } + boolean hasClientContext() { + return clientContext != null; + } + + RequestOptions.ClientContext clientContext() { + return clientContext; + } + boolean hasTag() { return tag != null; } @@ -777,6 +832,9 @@ public String toString() { if (priority != null) { b.append("priority: ").append(priority).append(' '); } + if (clientContext != null) { + b.append("clientContext: ").append(clientContext).append(' '); + } if (tag != null) { b.append("tag: ").append(tag).append(' '); } @@ -850,6 +908,7 @@ public boolean equals(Object o) { && Objects.equals(pageToken(), that.pageToken()) && Objects.equals(filter(), that.filter()) && Objects.equals(priority(), that.priority()) + && Objects.equals(clientContext(), that.clientContext()) && Objects.equals(tag(), that.tag()) && Objects.equals(etag(), that.etag()) && Objects.equals(validateOnly(), that.validateOnly()) @@ -894,6 +953,9 @@ public int hashCode() { if (priority != null) { result = 31 * result + priority.hashCode(); } + if (clientContext != null) { + result = 31 * result + clientContext.hashCode(); + } if (tag != null) { result = 31 * result + tag.hashCode(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index d86ff807afd..ee4269b107d 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -481,9 +481,12 @@ ApiFuture beginTransactionAsync( if (sessionReference.getIsMultiplexed() && mutation != null) { requestBuilder.setMutationKey(mutation); } - if (sessionReference.getIsMultiplexed() && !Strings.isNullOrEmpty(transactionOptions.tag())) { - requestBuilder.setRequestOptions( - RequestOptions.newBuilder().setTransactionTag(transactionOptions.tag()).build()); + RequestOptions requestOptions = transactionOptions.toRequestOptionsProto(true); + if (!sessionReference.getIsMultiplexed()) { + requestOptions = requestOptions.toBuilder().clearTransactionTag().build(); + } + if (!requestOptions.equals(RequestOptions.getDefaultInstance())) { + requestBuilder.setRequestOptions(requestOptions); } final BeginTransactionRequest request = requestBuilder.build(); final ApiFuture requestFuture; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java index 7afccce194c..d28566cef89 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java @@ -464,15 +464,9 @@ public void run() { waitForTransactionTimeoutMillis, TimeUnit.MILLISECONDS) : transactionId); } - if (options.hasPriority() || getTransactionTag() != null) { - RequestOptions.Builder requestOptionsBuilder = RequestOptions.newBuilder(); - if (options.hasPriority()) { - requestOptionsBuilder.setPriority(options.priority()); - } - if (getTransactionTag() != null) { - requestOptionsBuilder.setTransactionTag(getTransactionTag()); - } - requestBuilder.setRequestOptions(requestOptionsBuilder.build()); + RequestOptions requestOptions = options.toRequestOptionsProto(true); + if (!requestOptions.equals(RequestOptions.getDefaultInstance())) { + requestBuilder.setRequestOptions(requestOptions); } if (session.getIsMultiplexed() && getLatestPrecommitToken() != null) { // Set the precommit token in the CommitRequest for multiplexed sessions. diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java index eea6658d26d..d1f1f0b8c16 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java @@ -345,6 +345,18 @@ public void executeSqlRequestBuilderWithRequestOptionsWithTxnTag() { assertThat(request.getRequestOptions().getTransactionTag()).isEqualTo("app=spanner,env=test"); } + @Test + public void testBuildRequestOptionsWithClientContext() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + RequestOptions requestOptions = + context.buildRequestOptions(Options.fromQueryOptions(Options.clientContext(clientContext))); + assertEquals(clientContext, requestOptions.getClientContext()); + } + @Test public void testGetExecuteSqlRequestBuilderWithDirectedReadOptions() { ExecuteSqlRequest.Builder request = diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java index 8571c42b3dd..3edf9a61e17 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java @@ -34,17 +34,37 @@ import com.google.spanner.v1.DirectedReadOptions.ReplicaSelection; import com.google.spanner.v1.ReadRequest.LockHint; import com.google.spanner.v1.ReadRequest.OrderBy; -import com.google.spanner.v1.RequestOptions.Priority; -import com.google.spanner.v1.TransactionOptions.IsolationLevel; -import com.google.spanner.v1.TransactionOptions.ReadWrite; -import com.google.spanner.v1.TransactionOptions.ReadWrite.ReadLockMode; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import com.google.spanner.v1.RequestOptions; /** Unit tests for {@link Options}. */ @RunWith(JUnit4.class) public class OptionsTest { + @Test + public void testToRequestOptionsProto() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + Options options = + Options.fromQueryOptions( + Options.priority(RpcPriority.HIGH), + Options.tag("tag"), + Options.clientContext(clientContext)); + + RequestOptions protoForStatement = options.toRequestOptionsProto(false); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, protoForStatement.getPriority()); + assertEquals("tag", protoForStatement.getRequestTag()); + assertEquals("", protoForStatement.getTransactionTag()); + assertEquals(clientContext, protoForStatement.getClientContext()); + + RequestOptions protoForTransaction = options.toRequestOptionsProto(true); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, protoForTransaction.getPriority()); + assertEquals("", protoForTransaction.getRequestTag()); + assertEquals("tag", protoForTransaction.getTransactionTag()); + assertEquals(clientContext, protoForTransaction.getClientContext()); + } + private static final DirectedReadOptions DIRECTED_READ_OPTIONS = DirectedReadOptions.newBuilder() .setIncludeReplicas( diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java index 1ac3b7beaf7..35117568926 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java @@ -47,7 +47,10 @@ import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.Mutation.Write; import com.google.spanner.v1.PartialResultSet; +import com.google.spanner.v1.RequestOptions; +import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.RollbackRequest; import com.google.spanner.v1.Session; import com.google.spanner.v1.Transaction; @@ -77,6 +80,42 @@ /** Unit tests for {@link com.google.cloud.spanner.SessionImpl}. */ @RunWith(JUnit4.class) public class SessionImplTest { + @Test + public void testBeginTransactionWithClientContext() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + Mockito.when( + rpc.beginTransactionAsync( + Mockito.any(BeginTransactionRequest.class), anyMap(), eq(true))) + .thenReturn( + ApiFutures.immediateFuture( + Transaction.newBuilder().setId(ByteString.copyFromUtf8("tx")).build())); + + ((SessionImpl) session) + .beginTransactionAsync( + Options.fromTransactionOptions( + Options.priority(Options.RpcPriority.HIGH), + Options.tag("tag"), + Options.clientContext(clientContext)), + true, + Collections.emptyMap(), + null, + null); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(BeginTransactionRequest.class); + Mockito.verify(rpc).beginTransactionAsync(requestCaptor.capture(), anyMap(), eq(true)); + BeginTransactionRequest request = requestCaptor.getValue(); + RequestOptions requestOptions = request.getRequestOptions(); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, requestOptions.getPriority()); + // TransactionTag should NOT be set because session is not multiplexed. + assertEquals("", requestOptions.getTransactionTag()); + assertEquals(clientContext, requestOptions.getClientContext()); + } + @Mock private SpannerRpc rpc; @Mock private SpannerOptions spannerOptions; private com.google.cloud.spanner.Session session; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java index 3e3358a53bc..90dc59b1cdc 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.anyMap; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; @@ -53,6 +54,7 @@ import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.ResultSetStats; @@ -79,6 +81,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; @@ -99,6 +102,37 @@ public void release(ScheduledExecutorService exec) { } } + @Test + public void testCommitWithClientContext() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + Options options = + Options.fromTransactionOptions( + Options.priority(Options.RpcPriority.HIGH), + Options.tag("tag"), + Options.clientContext(clientContext)); + transactionRunner = new TransactionRunnerImpl(session, options); + when(session.getName()).thenReturn("projects/p/instances/i/databases/d/sessions/s"); + when(session.newTransaction(any(Options.class), any())).thenReturn(txn); + + transactionRunner.run( + transaction -> { + return null; + }); + + ArgumentCaptor commitRequestCaptor = + ArgumentCaptor.forClass(CommitRequest.class); + verify(rpc).commitAsync(commitRequestCaptor.capture(), anyMap()); + CommitRequest request = commitRequestCaptor.getValue(); + RequestOptions requestOptions = request.getRequestOptions(); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, requestOptions.getPriority()); + assertEquals("tag", requestOptions.getTransactionTag()); + assertEquals(clientContext, requestOptions.getClientContext()); + } + @Mock private SpannerRpc rpc; @Mock private SessionImpl session; @Mock private TransactionRunnerImpl.TransactionContextImpl txn; From 61f82cba40e9e1c8b1daca9ae2cd571fb87f8fe8 Mon Sep 17 00:00:00 2001 From: Adam Seering Date: Wed, 14 Jan 2026 12:56:54 +0000 Subject: [PATCH 2/2] feat: Add ClientContext support to Connection API This change adds support for setting and propagating ClientContext in the Spanner Connection API. ClientContext allows propagating client-scoped session state (e.g., secure parameters) to Spanner RPCs. - Added setClientContext/getClientContext to Connection interface and implementation. - Implemented state propagation from Connection to UnitOfWork and its implementations (ReadWriteTransaction, SingleUseTransaction). - Fixed accidental import removal in OptionsTest.java. - Fixed TransactionRunnerImplTest to correctly verify ClientContext propagation. - Added ClientContextMockServerTest for end-to-end verification. --- .../com/google/cloud/spanner/Options.java | 2 +- .../connection/AbstractBaseUnitOfWork.java | 8 + .../cloud/spanner/connection/Connection.java | 19 ++ .../spanner/connection/ConnectionImpl.java | 31 +++ .../connection/ReadWriteTransaction.java | 6 + .../connection/SingleUseTransaction.java | 6 + .../com/google/cloud/spanner/OptionsTest.java | 7 + .../spanner/TransactionRunnerImplTest.java | 26 +- .../ClientContextMockServerTest.java | 244 ++++++++++++++++++ .../connection/ConnectionImplTest.java | 29 +++ 10 files changed, 364 insertions(+), 14 deletions(-) create mode 100644 google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ClientContextMockServerTest.java diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java index 1a5de25172b..88810c6b73b 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java @@ -268,7 +268,7 @@ public static ReadQueryUpdateTransactionOption priority(RpcPriority priority) { /** * Specifying this will add the given client context to the request. The client context is used to - * pass opaque side-channel information to the backend, such as a user ID for a parameterized + * pass side-channel or configuration information to the backend, such as a user ID for a parameterized * secure view. */ public static ReadQueryUpdateTransactionOption clientContext( diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java index 75a207043c2..1d71e062cbb 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java @@ -80,6 +80,7 @@ abstract class AbstractBaseUnitOfWork implements UnitOfWork { protected final List transactionRetryListeners; protected final boolean excludeTxnFromChangeStreams; protected final RpcPriority rpcPriority; + protected final com.google.spanner.v1.RequestOptions.ClientContext clientContext; protected final Span span; /** Class for keeping track of the stacktrace of the caller of an async statement. */ @@ -117,6 +118,7 @@ abstract static class Builder, T extends AbstractBaseUni private boolean excludeTxnFromChangeStreams; private RpcPriority rpcPriority; + private com.google.spanner.v1.RequestOptions.ClientContext clientContext; private Span span; Builder() {} @@ -163,6 +165,11 @@ B setRpcPriority(@Nullable RpcPriority rpcPriority) { return self(); } + B setClientContext(@Nullable com.google.spanner.v1.RequestOptions.ClientContext clientContext) { + this.clientContext = clientContext; + return self(); + } + B setSpan(@Nullable Span span) { this.span = span; return self(); @@ -179,6 +186,7 @@ B setSpan(@Nullable Span span) { this.transactionRetryListeners = builder.transactionRetryListeners; this.excludeTxnFromChangeStreams = builder.excludeTxnFromChangeStreams; this.rpcPriority = builder.rpcPriority; + this.clientContext = builder.clientContext; this.span = Preconditions.checkNotNull(builder.span); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java index 533be8a047f..60d739a3c85 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java @@ -449,6 +449,25 @@ default String getStatementTag() { throw new UnsupportedOperationException(); } + /** + * Sets the client context to use for the statements that are executed. The client context + * persists until it is changed or cleared. + * + * @param clientContext The client context to use with the statements that will be executed on + * this connection. + */ + default void setClientContext(com.google.spanner.v1.RequestOptions.ClientContext clientContext) { + throw new UnsupportedOperationException(); + } + + /** + * @return The client context that will be used with the statements that are executed on this + * connection. + */ + default com.google.spanner.v1.RequestOptions.ClientContext getClientContext() { + throw new UnsupportedOperationException(); + } + /** * Sets whether the next transaction should be excluded from all change streams with the DDL * option `allow_txn_exclusion=true` diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java index cfd63c89d49..cadd6375739 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java @@ -94,6 +94,7 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.spanner.v1.DirectedReadOptions; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.TransactionOptions.IsolationLevel; import com.google.spanner.v1.TransactionOptions.ReadWrite.ReadLockMode; @@ -299,6 +300,7 @@ static UnitOfWorkType of(TransactionMode transactionMode) { private IsolationLevel transactionIsolationLevel; private String transactionTag; private String statementTag; + private RequestOptions.ClientContext clientContext; private boolean excludeTxnFromChangeStreams; private byte[] protoDescriptors; private String protoDescriptorsFilePath; @@ -536,6 +538,7 @@ private void reset(Context context, boolean inTransaction) { this.connectionState.resetValue(SAVEPOINT_SUPPORT, context, inTransaction); this.protoDescriptors = null; this.protoDescriptorsFilePath = null; + this.clientContext = null; if (!isTransactionStarted()) { setDefaultTransactionOptions(getDefaultIsolationLevel()); @@ -955,6 +958,18 @@ public String getTransactionTag() { return transactionTag; } + @Override + public void setClientContext(RequestOptions.ClientContext clientContext) { + ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG); + this.clientContext = clientContext; + } + + @Override + public RequestOptions.ClientContext getClientContext() { + ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG); + return clientContext; + } + @Override public void setTransactionTag(String tag) { ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG); @@ -2026,6 +2041,9 @@ private QueryOption[] mergeQueryRequestOptions( options = appendQueryOption(options, Options.priority(getConnectionPropertyValue(RPC_PRIORITY))); } + if (clientContext != null) { + options = appendQueryOption(options, Options.clientContext(clientContext)); + } if (currentUnitOfWork != null && currentUnitOfWork.supportsDirectedReads(parsedStatement) && getConnectionPropertyValue(DIRECTED_READ) != null) { @@ -2070,6 +2088,14 @@ private UpdateOption[] mergeUpdateRequestOptions(UpdateOption... options) { options[options.length - 1] = Options.priority(getConnectionPropertyValue(RPC_PRIORITY)); } } + if (clientContext != null) { + if (options == null || options.length == 0) { + options = new UpdateOption[] {Options.clientContext(clientContext)}; + } else { + options = Arrays.copyOf(options, options.length + 1); + options[options.length - 1] = Options.clientContext(clientContext); + } + } return options; } @@ -2299,6 +2325,7 @@ UnitOfWork createNewUnitOfWork( createSpanForUnitOfWork( statementType == StatementType.DDL ? DDL_STATEMENT : SINGLE_USE_TRANSACTION)) .setProtoDescriptors(getProtoDescriptors()) + .setClientContext(clientContext) .build(); if (!isInternalMetadataQuery && !forceSingleUse) { // Reset the transaction options after starting a single-use transaction. @@ -2317,6 +2344,7 @@ UnitOfWork createNewUnitOfWork( .setTransactionTag(transactionTag) .setRpcPriority(getConnectionPropertyValue(RPC_PRIORITY)) .setSpan(createSpanForUnitOfWork(READ_ONLY_TRANSACTION)) + .setClientContext(clientContext) .build(); case READ_WRITE_TRANSACTION: return ReadWriteTransaction.newBuilder() @@ -2340,6 +2368,7 @@ UnitOfWork createNewUnitOfWork( .setExcludeTxnFromChangeStreams(excludeTxnFromChangeStreams) .setRpcPriority(getConnectionPropertyValue(RPC_PRIORITY)) .setSpan(createSpanForUnitOfWork(READ_WRITE_TRANSACTION)) + .setClientContext(clientContext) .build(); case DML_BATCH: // A DML batch can run inside the current transaction. It should therefore only @@ -2359,6 +2388,7 @@ UnitOfWork createNewUnitOfWork( .setRpcPriority(getConnectionPropertyValue(RPC_PRIORITY)) // Use the transaction Span for the DML batch. .setSpan(transactionStack.peek().getSpan()) + .setClientContext(clientContext) .build(); case DDL_BATCH: return DdlBatch.newBuilder() @@ -2369,6 +2399,7 @@ UnitOfWork createNewUnitOfWork( .setSpan(createSpanForUnitOfWork(DDL_BATCH)) .setProtoDescriptors(getProtoDescriptors()) .setConnectionState(connectionState) + .setClientContext(clientContext) .build(); default: } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java index c0e464ee5e6..ccb592e3f84 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java @@ -350,6 +350,9 @@ private TransactionOption[] extractOptions(Builder builder) { if (this.readLockMode != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { numOptions++; } + if (this.clientContext != null) { + numOptions++; + } TransactionOption[] options = new TransactionOption[numOptions]; int index = 0; if (builder.returnCommitStats) { @@ -373,6 +376,9 @@ private TransactionOption[] extractOptions(Builder builder) { if (this.readLockMode != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { options[index++] = Options.readLockMode(this.readLockMode); } + if (this.clientContext != null) { + options[index++] = Options.clientContext(this.clientContext); + } return options; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java index 370b579e6e2..cfb13cef966 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java @@ -520,6 +520,9 @@ private TransactionRunner createWriteTransaction() { != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { numOptions++; } + if (this.clientContext != null) { + numOptions++; + } if (numOptions == 0) { return dbClient.readWriteTransaction(); } @@ -547,6 +550,9 @@ private TransactionRunner createWriteTransaction() { != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { options[index++] = Options.readLockMode(connectionState.getValue(READ_LOCK_MODE).getValue()); } + if (this.clientContext != null) { + options[index++] = Options.clientContext(this.clientContext); + } return dbClient.readWriteTransaction(options); } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java index 3edf9a61e17..52cd2db7798 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java @@ -35,6 +35,13 @@ import com.google.spanner.v1.ReadRequest.LockHint; import com.google.spanner.v1.ReadRequest.OrderBy; import com.google.spanner.v1.RequestOptions; +import com.google.spanner.v1.RequestOptions.Priority; +import com.google.spanner.v1.TransactionOptions.IsolationLevel; +import com.google.spanner.v1.TransactionOptions.ReadWrite; +import com.google.spanner.v1.TransactionOptions.ReadWrite.ReadLockMode; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; /** Unit tests for {@link Options}. */ @RunWith(JUnit4.class) diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java index 90dc59b1cdc..15ba8ec5acf 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java @@ -109,28 +109,28 @@ public void testCommitWithClientContext() { .putSecureContext( "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) .build(); - Options options = - Options.fromTransactionOptions( + when(session.getName()).thenReturn("projects/p/instances/i/databases/d/sessions/s"); + when(session.newTransaction(any(Options.class), any())).thenReturn(txn); + Mockito.clearInvocations(session); + transactionRunner = + new TransactionRunnerImpl( + session, Options.priority(Options.RpcPriority.HIGH), Options.tag("tag"), Options.clientContext(clientContext)); - transactionRunner = new TransactionRunnerImpl(session, options); - when(session.getName()).thenReturn("projects/p/instances/i/databases/d/sessions/s"); - when(session.newTransaction(any(Options.class), any())).thenReturn(txn); + transactionRunner.setSpan(span); transactionRunner.run( transaction -> { return null; }); - ArgumentCaptor commitRequestCaptor = - ArgumentCaptor.forClass(CommitRequest.class); - verify(rpc).commitAsync(commitRequestCaptor.capture(), anyMap()); - CommitRequest request = commitRequestCaptor.getValue(); - RequestOptions requestOptions = request.getRequestOptions(); - assertEquals(RequestOptions.Priority.PRIORITY_HIGH, requestOptions.getPriority()); - assertEquals("tag", requestOptions.getTransactionTag()); - assertEquals(clientContext, requestOptions.getClientContext()); + ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(Options.class); + verify(session).newTransaction(optionsCaptor.capture(), any()); + Options capturedOptions = optionsCaptor.getValue(); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, capturedOptions.priority()); + assertEquals("tag", capturedOptions.tag()); + assertEquals(clientContext, capturedOptions.clientContext()); } @Mock private SpannerRpc rpc; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ClientContextMockServerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ClientContextMockServerTest.java new file mode 100644 index 00000000000..63287e1afac --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ClientContextMockServerTest.java @@ -0,0 +1,244 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.connection; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; + +import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.ResultSet; +import com.google.protobuf.Value; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.RequestOptions; +import java.util.Collections; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class ClientContextMockServerTest extends AbstractMockServerTest { + + @Parameters(name = "dialect = {0}") + public static Object[] data() { + return Dialect.values(); + } + + @Parameter public Dialect dialect; + + private Dialect currentDialect; + + private static final RequestOptions.ClientContext CLIENT_CONTEXT = + RequestOptions.ClientContext.newBuilder() + .putSecureContext("test-key", Value.newBuilder().setStringValue("test-value").build()) + .build(); + + @Before + public void setupDialect() { + if (currentDialect != dialect) { + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.detectDialectResult(dialect)); + SpannerPool.closeSpannerPool(); + currentDialect = dialect; + } + } + + @After + public void clearRequests() { + mockSpanner.clearRequests(); + } + + @Test + public void testQuery_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testUpdate_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.executeUpdate(INSERT_STATEMENT); + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testBatchUpdate_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.executeBatchUpdate(Collections.singletonList(INSERT_STATEMENT)); + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteBatchDmlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testCommit_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.executeUpdate(INSERT_STATEMENT); + connection.commit(); + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(CommitRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testBeginTransaction_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.beginTransaction(); + // BeginTransaction is executed lazily, so we need to execute a statement to trigger it. + // However, the Connection API might do it eagerly if we call beginTransaction(). + // Let's check the implementation of beginTransactionAsync. + // It calls transactionManager.begin() which eventually calls BeginTransaction RPC. + // But ReadWriteTransaction waits until the first statement unless DELAY_TRANSACTION_START is false. + // Actually, ConnectionImpl.beginTransactionAsync calls clearLastTransactionAndSetDefaultTransactionOptions. + // It does NOT start the transaction on Spanner immediately. + // The transaction is started when the first statement is executed. + + // So let's execute a statement. + connection.executeUpdate(INSERT_STATEMENT); + + // Now checking requests. + // If the transaction started lazily (default), the first statement might be BeginTransaction? + // Or if it's ReadWrite, it might use inline BeginTransaction. + // If it uses inline BeginTransaction, then ExecuteSqlRequest will have transaction.begin. + + // Let's force an explicit BeginTransaction by executing a statement that requires it, + // or by checking if we can force it. + // But typically, for RW transaction, the first statement has BeginTransaction option. + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertEquals(CLIENT_CONTEXT, request.getRequestOptions().getClientContext()); + + // If we want to test BeginTransaction RPC specifically, we need to ensure it is called. + // It is called if we do `connection.beginTransaction()` and then something that triggers it? + // Or if we use `DELAY_TRANSACTION_START_UNTIL_FIRST_WRITE=false`? + // The default is `true` (I think). + + // If we use explicit transaction management via Connection, `beginTransaction` just sets state. + } + } + + @Test + public void testPersistence() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + + connection.executeUpdate(INSERT_STATEMENT); + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(1) + .getRequestOptions() + .getClientContext()); + + connection.commit(); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(CommitRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testClearClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + + connection.setClientContext(null); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertFalse( + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(1) + .getRequestOptions() + .hasClientContext()); + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java index ead2fd0f655..d430c04bbd1 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java @@ -1948,6 +1948,35 @@ private void assertThrowResultNotAllowed( "Only statements that return a result of one of the following types are allowed")); } + @Test + public void testSetAndGetClientContext() { + try (Connection connection = createConnection(ConnectionOptions.newBuilder().setUri(URI).build())) { + com.google.spanner.v1.RequestOptions.ClientContext context = + com.google.spanner.v1.RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("test").build()) + .build(); + connection.setClientContext(context); + assertEquals(context, connection.getClientContext()); + } + } + + @Test + public void testResetClearsClientContext() { + try (Connection connection = createConnection(ConnectionOptions.newBuilder().setUri(URI).build())) { + com.google.spanner.v1.RequestOptions.ClientContext context = + com.google.spanner.v1.RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("test").build()) + .build(); + connection.setClientContext(context); + assertEquals(context, connection.getClientContext()); + + connection.reset(); + assertNull(connection.getClientContext()); + } + } + @Test public void testProtoDescriptorsAlwaysAllowed() { ConnectionOptions connectionOptions = mock(ConnectionOptions.class);