Skip to content

Commit cb13eef

Browse files
committed
Don't return empty routing key when partition key is unbound
DefaultBoundStatement#getRoutingKey has logic to infer the routing key when no one has explicitly called setRoutingKey or otherwise set the routing key on the statement. It however doesn't check for cases where nothing has been bound yet on the statement. This causes more problems if the user decides to get a BoundStatementBuilder from the PreparedStatement, set some fields on it, and then copy it by constructing new BoundStatementBuilder objects with the BoundStatement as a parameter, since the empty ByteBuffer gets copied to all bound statements, resulting in all requests being targeted to the same Cassandra node in a token-aware load balancing policy.
1 parent a650ee4 commit cb13eef

File tree

2 files changed

+102
-1
lines changed

2 files changed

+102
-1
lines changed

core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ public ByteBuffer getRoutingKey() {
358358
if (indices.isEmpty()) {
359359
return null;
360360
} else if (indices.size() == 1) {
361-
return getBytesUnsafe(indices.get(0));
361+
return isSet(0) ? getBytesUnsafe(indices.get(0)) : null;
362362
} else {
363363
ByteBuffer[] components = new ByteBuffer[indices.size()];
364364
for (int i = 0; i < components.length; i++) {

core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementBuilderTest.java

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,30 @@
1919
import static org.mockito.Mockito.mock;
2020
import static org.mockito.Mockito.when;
2121

22+
import com.datastax.oss.driver.api.core.DefaultProtocolVersion;
23+
import com.datastax.oss.driver.api.core.context.DriverContext;
24+
import com.datastax.oss.driver.api.core.type.DataType;
25+
import com.datastax.oss.driver.api.core.type.DataTypes;
26+
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
27+
import com.datastax.oss.driver.internal.core.cql.DefaultColumnDefinition;
28+
import com.datastax.oss.driver.internal.core.cql.DefaultColumnDefinitions;
29+
import com.datastax.oss.driver.internal.core.cql.DefaultPreparedStatement;
2230
import com.datastax.oss.driver.shaded.guava.common.base.Charsets;
31+
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList;
32+
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap;
33+
import com.datastax.oss.protocol.internal.response.result.ColumnSpec;
34+
import com.datastax.oss.protocol.internal.response.result.RawType;
35+
import com.datastax.oss.protocol.internal.util.Bytes;
2336
import java.nio.ByteBuffer;
37+
import java.util.Collections;
38+
import java.util.Map;
39+
import org.junit.Before;
2440
import org.junit.Test;
41+
import org.junit.runner.RunWith;
42+
import org.mockito.Mock;
43+
import org.mockito.junit.MockitoJUnitRunner;
2544

45+
@RunWith(MockitoJUnitRunner.class)
2646
public class StatementBuilderTest {
2747

2848
private static class MockSimpleStatementBuilder
@@ -46,6 +66,14 @@ public SimpleStatement build() {
4666
}
4767
}
4868

69+
@Mock private DriverContext context;
70+
71+
@Before
72+
public void setup() {
73+
when(context.getCodecRegistry()).thenReturn(CodecRegistry.DEFAULT);
74+
when(context.getProtocolVersion()).thenReturn(DefaultProtocolVersion.V4);
75+
}
76+
4977
@Test
5078
public void should_handle_set_tracing_without_args() {
5179

@@ -99,4 +127,77 @@ public void should_match_set_routing_key_vararg() {
99127
builderStmt = builder.setRoutingKey(buff2, buff1).build();
100128
assertThat(expectedStmt.getRoutingKey()).isNotEqualTo(builderStmt.getRoutingKey());
101129
}
130+
131+
@Test
132+
public void should_correctly_set_routing_key_on_boundstatement() {
133+
PreparedStatement preparedStatement =
134+
mockPreparedStatement(
135+
"UPDATE foo SET v=? WHERE k=?",
136+
ImmutableMap.of("v", DataTypes.INT, "k", DataTypes.INT));
137+
BoundStatement boundStatement = preparedStatement.boundStatementBuilder().build();
138+
// we should not have a routing key set
139+
assertThat(boundStatement.getRoutingKey()).isNull();
140+
boundStatement = boundStatement.set(0, 1, Integer.class).set(1, 2, Integer.class);
141+
// we should have a non-empty routing key now
142+
assertThat(boundStatement.getRoutingKey()).isNotNull();
143+
assertThat(boundStatement.getRoutingKey().hasRemaining()).isTrue();
144+
}
145+
146+
@Test
147+
public void should_correctly_set_routing_key_on_boundstatementbuilder() {
148+
PreparedStatement preparedStatement =
149+
mockPreparedStatement(
150+
"UPDATE foo SET v=? WHERE k=?",
151+
ImmutableMap.of("v", DataTypes.INT, "k", DataTypes.INT));
152+
BoundStatement boundStatement = preparedStatement.boundStatementBuilder().build();
153+
BoundStatement copy = new BoundStatementBuilder(boundStatement).build();
154+
// we should not have a routing key set
155+
assertThat(copy.getRoutingKey()).isNull();
156+
copy = copy.set(0, 1, Integer.class).set(1, 2, Integer.class);
157+
// we should have a non-empty routing key now
158+
assertThat(copy.getRoutingKey()).isNotNull();
159+
assertThat(copy.getRoutingKey().hasRemaining()).isTrue();
160+
}
161+
162+
// copied from RequestLogFormatterTest, we should move somewhere to share b/w tests
163+
private PreparedStatement mockPreparedStatement(String query, Map<String, DataType> variables) {
164+
ImmutableList.Builder<ColumnDefinition> definitions = ImmutableList.builder();
165+
int i = 0;
166+
for (Map.Entry<String, DataType> entry : variables.entrySet()) {
167+
definitions.add(
168+
new DefaultColumnDefinition(
169+
new ColumnSpec(
170+
"test",
171+
"foo",
172+
entry.getKey(),
173+
i,
174+
RawType.PRIMITIVES.get(entry.getValue().getProtocolCode())),
175+
context));
176+
}
177+
return new DefaultPreparedStatement(
178+
Bytes.fromHexString("0x"),
179+
query,
180+
DefaultColumnDefinitions.valueOf(definitions.build()),
181+
Collections.singletonList(
182+
1), // note that this line is different from the one in RequestLogFormatterTest
183+
null,
184+
null,
185+
null,
186+
Collections.emptyMap(),
187+
null,
188+
null,
189+
null,
190+
null,
191+
null,
192+
Collections.emptyMap(),
193+
null,
194+
null,
195+
null,
196+
Integer.MIN_VALUE,
197+
null,
198+
null,
199+
false,
200+
context.getCodecRegistry(),
201+
context.getProtocolVersion());
202+
}
102203
}

0 commit comments

Comments
 (0)