diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java index e4196e2ac63..46107de897a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java @@ -52,6 +52,7 @@ class GrpcStreamIterator extends AbstractIterator private TimeUnit streamWaitTimeoutUnit; private long streamWaitTimeoutValue; private SpannerException error; + private boolean done; @VisibleForTesting GrpcStreamIterator(int prefetchChunks, boolean cancelQueryWhenClientIsClosed) { @@ -166,11 +167,17 @@ private class ConsumerImpl implements SpannerRpc.ResultStreamConsumer { @Override public void onPartialResultSet(PartialResultSet results) { addToStream(results); + if (results.getLast()) { + done = true; + addToStream(END_OF_STREAM); + } } @Override public void onCompleted() { - addToStream(END_OF_STREAM); + if (!done) { + addToStream(END_OF_STREAM); + } } @Override diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java index 25c01560e92..5faaf3fd817 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java @@ -19,6 +19,7 @@ import static com.google.common.testing.SerializableTester.reserialize; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; @@ -1115,4 +1116,58 @@ public void getProtoEnumList() { resultSet.getProtoEnum(0, Genre::forNumber); }); } + + @Test + public void verifyResultSetWithLastTrue() { + long[] longArray = {111, 333, 444, 0, -1, -2234, Long.MAX_VALUE, Long.MIN_VALUE}; + + consumer.onPartialResultSet( + PartialResultSet.newBuilder() + .setMetadata( + makeMetadata(Type.struct(Type.StructField.of("f", Type.array(Type.int64()))))) + .addValues(Value.int64Array(longArray).toProto()) + .setLast(false) + .build()); + assertTrue(resultSet.next()); + consumer.onPartialResultSet( + PartialResultSet.newBuilder() + .setMetadata( + makeMetadata(Type.struct(Type.StructField.of("f", Type.array(Type.int64()))))) + .addValues(Value.int64Array(longArray).toProto()) + .setLast(true) + .build()); + assertTrue(resultSet.next()); + assertFalse(resultSet.next()); + consumer.onCompleted(); + } + + @Test + public void shouldThrowDeadlineExceededIfLastTrueIsNotReceived() { + long[] longArray = {111, 333, 444, 0, -1, -2234, Long.MAX_VALUE, Long.MIN_VALUE}; + + consumer.onPartialResultSet( + PartialResultSet.newBuilder() + .setMetadata( + makeMetadata(Type.struct(Type.StructField.of("f", Type.array(Type.int64()))))) + .addValues(Value.int64Array(longArray).toProto()) + .setLast(false) + .build()); + assertTrue(resultSet.next()); + consumer.onPartialResultSet( + PartialResultSet.newBuilder() + .setMetadata( + makeMetadata(Type.struct(Type.StructField.of("f", Type.array(Type.int64()))))) + .addValues(Value.int64Array(longArray).toProto()) + .setLast(false) + .build()); + assertTrue(resultSet.next()); + SpannerException spannerException = + assertThrows( + SpannerException.class, + () -> { + assertThat(resultSet.next()).isFalse(); + }); + assertEquals("DEADLINE_EXCEEDED: stream wait timeout", spannerException.getMessage()); + consumer.onCompleted(); + } }