1919import static org .mockito .Mockito .mock ;
2020import static org .mockito .Mockito .when ;
2121
22+ import com .datastax .oss .driver .api .core .DefaultProtocolVersion ;
23+ import com .datastax .oss .driver .api .core .context .DriverContext ;
24+ import com .datastax .oss .driver .api .core .type .DataType ;
25+ import com .datastax .oss .driver .api .core .type .DataTypes ;
26+ import com .datastax .oss .driver .api .core .type .codec .registry .CodecRegistry ;
27+ import com .datastax .oss .driver .internal .core .cql .DefaultColumnDefinition ;
28+ import com .datastax .oss .driver .internal .core .cql .DefaultColumnDefinitions ;
29+ import com .datastax .oss .driver .internal .core .cql .DefaultPreparedStatement ;
2230import com .datastax .oss .driver .shaded .guava .common .base .Charsets ;
31+ import com .datastax .oss .driver .shaded .guava .common .collect .ImmutableList ;
32+ import com .datastax .oss .driver .shaded .guava .common .collect .ImmutableMap ;
33+ import com .datastax .oss .protocol .internal .response .result .ColumnSpec ;
34+ import com .datastax .oss .protocol .internal .response .result .RawType ;
35+ import com .datastax .oss .protocol .internal .util .Bytes ;
2336import java .nio .ByteBuffer ;
37+ import java .util .Collections ;
38+ import java .util .Map ;
39+ import org .junit .Before ;
2440import org .junit .Test ;
41+ import org .junit .runner .RunWith ;
42+ import org .mockito .Mock ;
43+ import org .mockito .junit .MockitoJUnitRunner ;
2544
45+ @ RunWith (MockitoJUnitRunner .class )
2646public class StatementBuilderTest {
2747
2848 private static class MockSimpleStatementBuilder
@@ -46,6 +66,14 @@ public SimpleStatement build() {
4666 }
4767 }
4868
69+ @ Mock private DriverContext context ;
70+
71+ @ Before
72+ public void setup () {
73+ when (context .getCodecRegistry ()).thenReturn (CodecRegistry .DEFAULT );
74+ when (context .getProtocolVersion ()).thenReturn (DefaultProtocolVersion .V4 );
75+ }
76+
4977 @ Test
5078 public void should_handle_set_tracing_without_args () {
5179
@@ -99,4 +127,72 @@ public void should_match_set_routing_key_vararg() {
99127 builderStmt = builder .setRoutingKey (buff2 , buff1 ).build ();
100128 assertThat (expectedStmt .getRoutingKey ()).isNotEqualTo (builderStmt .getRoutingKey ());
101129 }
130+
131+ @ Test
132+ public void should_correctly_set_routing_key_on_boundstatement () {
133+ PreparedStatement preparedStatement = mockPreparedStatement ("UPDATE foo SET v=? WHERE k=?" ,
134+ ImmutableMap .of ("v" , DataTypes .INT , "k" , DataTypes .INT ));
135+ BoundStatement boundStatement = preparedStatement .boundStatementBuilder ().build ();
136+ // we should not have a routing key set
137+ assertThat (boundStatement .getRoutingKey ()).isNull ();
138+ boundStatement = boundStatement .set (0 , 1 , Integer .class ).set (1 , 2 , Integer .class );
139+ // we should have a non-empty routing key now
140+ assertThat (boundStatement .getRoutingKey ()).isNotNull ();
141+ assertThat (boundStatement .getRoutingKey ().hasRemaining ()).isTrue ();
142+ }
143+
144+ @ Test
145+ public void should_correctly_set_routing_key_on_boundstatementbuilder () {
146+ PreparedStatement preparedStatement = mockPreparedStatement ("UPDATE foo SET v=? WHERE k=?" ,
147+ ImmutableMap .of ("v" , DataTypes .INT , "k" , DataTypes .INT ));
148+ BoundStatement boundStatement = preparedStatement .boundStatementBuilder ().build ();
149+ BoundStatement copy = new BoundStatementBuilder (boundStatement ).build ();
150+ // we should not have a routing key set
151+ assertThat (copy .getRoutingKey ()).isNull ();
152+ copy = copy .set (0 , 1 , Integer .class ).set (1 , 2 , Integer .class );
153+ // we should have a non-empty routing key now
154+ assertThat (copy .getRoutingKey ()).isNotNull ();
155+ assertThat (copy .getRoutingKey ().hasRemaining ()).isTrue ();
156+ }
157+
158+ // copied from RequestLogFormatterTest, we should move somewhere to share b/w tests
159+ private PreparedStatement mockPreparedStatement (String query , Map <String , DataType > variables ) {
160+ ImmutableList .Builder <ColumnDefinition > definitions = ImmutableList .builder ();
161+ int i = 0 ;
162+ for (Map .Entry <String , DataType > entry : variables .entrySet ()) {
163+ definitions .add (
164+ new DefaultColumnDefinition (
165+ new ColumnSpec (
166+ "test" ,
167+ "foo" ,
168+ entry .getKey (),
169+ i ,
170+ RawType .PRIMITIVES .get (entry .getValue ().getProtocolCode ())),
171+ context ));
172+ }
173+ return new DefaultPreparedStatement (
174+ Bytes .fromHexString ("0x" ),
175+ query ,
176+ DefaultColumnDefinitions .valueOf (definitions .build ()),
177+ Collections .singletonList (1 ), // note that this line is different from the one in RequestLogFormatterTest
178+ null ,
179+ null ,
180+ null ,
181+ Collections .emptyMap (),
182+ null ,
183+ null ,
184+ null ,
185+ null ,
186+ null ,
187+ Collections .emptyMap (),
188+ null ,
189+ null ,
190+ null ,
191+ Integer .MIN_VALUE ,
192+ null ,
193+ null ,
194+ false ,
195+ context .getCodecRegistry (),
196+ context .getProtocolVersion ());
197+ }
102198}
0 commit comments