1919import static org .mockito .Mockito .mock ;
2020import static org .mockito .Mockito .when ;
2121
22+ import com .datastax .oss .driver .api .core .DefaultProtocolVersion ;
23+ import com .datastax .oss .driver .api .core .context .DriverContext ;
24+ import com .datastax .oss .driver .api .core .type .DataType ;
25+ import com .datastax .oss .driver .api .core .type .DataTypes ;
26+ import com .datastax .oss .driver .api .core .type .codec .registry .CodecRegistry ;
27+ import com .datastax .oss .driver .internal .core .cql .DefaultColumnDefinition ;
28+ import com .datastax .oss .driver .internal .core .cql .DefaultColumnDefinitions ;
29+ import com .datastax .oss .driver .internal .core .cql .DefaultPreparedStatement ;
2230import com .datastax .oss .driver .shaded .guava .common .base .Charsets ;
31+ import com .datastax .oss .driver .shaded .guava .common .collect .ImmutableList ;
32+ import com .datastax .oss .driver .shaded .guava .common .collect .ImmutableMap ;
33+ import com .datastax .oss .protocol .internal .response .result .ColumnSpec ;
34+ import com .datastax .oss .protocol .internal .response .result .RawType ;
35+ import com .datastax .oss .protocol .internal .util .Bytes ;
2336import java .nio .ByteBuffer ;
37+ import java .util .Collections ;
38+ import java .util .Map ;
39+ import org .junit .Before ;
2440import org .junit .Test ;
41+ import org .junit .runner .RunWith ;
42+ import org .mockito .Mock ;
43+ import org .mockito .junit .MockitoJUnitRunner ;
2544
45+ @ RunWith (MockitoJUnitRunner .class )
2646public class StatementBuilderTest {
2747
2848 private static class MockSimpleStatementBuilder
@@ -46,6 +66,14 @@ public SimpleStatement build() {
4666 }
4767 }
4868
69+ @ Mock private DriverContext context ;
70+
71+ @ Before
72+ public void setup () {
73+ when (context .getCodecRegistry ()).thenReturn (CodecRegistry .DEFAULT );
74+ when (context .getProtocolVersion ()).thenReturn (DefaultProtocolVersion .V4 );
75+ }
76+
4977 @ Test
5078 public void should_handle_set_tracing_without_args () {
5179
@@ -99,4 +127,77 @@ public void should_match_set_routing_key_vararg() {
99127 builderStmt = builder .setRoutingKey (buff2 , buff1 ).build ();
100128 assertThat (expectedStmt .getRoutingKey ()).isNotEqualTo (builderStmt .getRoutingKey ());
101129 }
130+
131+ @ Test
132+ public void should_correctly_set_routing_key_on_boundstatement () {
133+ PreparedStatement preparedStatement =
134+ mockPreparedStatement (
135+ "UPDATE foo SET v=? WHERE k=?" ,
136+ ImmutableMap .of ("v" , DataTypes .INT , "k" , DataTypes .INT ));
137+ BoundStatement boundStatement = preparedStatement .boundStatementBuilder ().build ();
138+ // we should not have a routing key set
139+ assertThat (boundStatement .getRoutingKey ()).isNull ();
140+ boundStatement = boundStatement .set (0 , 1 , Integer .class ).set (1 , 2 , Integer .class );
141+ // we should have a non-empty routing key now
142+ assertThat (boundStatement .getRoutingKey ()).isNotNull ();
143+ assertThat (boundStatement .getRoutingKey ().hasRemaining ()).isTrue ();
144+ }
145+
146+ @ Test
147+ public void should_correctly_set_routing_key_on_boundstatementbuilder () {
148+ PreparedStatement preparedStatement =
149+ mockPreparedStatement (
150+ "UPDATE foo SET v=? WHERE k=?" ,
151+ ImmutableMap .of ("v" , DataTypes .INT , "k" , DataTypes .INT ));
152+ BoundStatement boundStatement = preparedStatement .boundStatementBuilder ().build ();
153+ BoundStatement copy = new BoundStatementBuilder (boundStatement ).build ();
154+ // we should not have a routing key set
155+ assertThat (copy .getRoutingKey ()).isNull ();
156+ copy = copy .set (0 , 1 , Integer .class ).set (1 , 2 , Integer .class );
157+ // we should have a non-empty routing key now
158+ assertThat (copy .getRoutingKey ()).isNotNull ();
159+ assertThat (copy .getRoutingKey ().hasRemaining ()).isTrue ();
160+ }
161+
162+ // copied from RequestLogFormatterTest, we should move somewhere to share b/w tests
163+ private PreparedStatement mockPreparedStatement (String query , Map <String , DataType > variables ) {
164+ ImmutableList .Builder <ColumnDefinition > definitions = ImmutableList .builder ();
165+ int i = 0 ;
166+ for (Map .Entry <String , DataType > entry : variables .entrySet ()) {
167+ definitions .add (
168+ new DefaultColumnDefinition (
169+ new ColumnSpec (
170+ "test" ,
171+ "foo" ,
172+ entry .getKey (),
173+ i ,
174+ RawType .PRIMITIVES .get (entry .getValue ().getProtocolCode ())),
175+ context ));
176+ }
177+ return new DefaultPreparedStatement (
178+ Bytes .fromHexString ("0x" ),
179+ query ,
180+ DefaultColumnDefinitions .valueOf (definitions .build ()),
181+ Collections .singletonList (
182+ 1 ), // note that this line is different from the one in RequestLogFormatterTest
183+ null ,
184+ null ,
185+ null ,
186+ Collections .emptyMap (),
187+ null ,
188+ null ,
189+ null ,
190+ null ,
191+ null ,
192+ Collections .emptyMap (),
193+ null ,
194+ null ,
195+ null ,
196+ Integer .MIN_VALUE ,
197+ null ,
198+ null ,
199+ false ,
200+ context .getCodecRegistry (),
201+ context .getProtocolVersion ());
202+ }
102203}
0 commit comments