diff --git a/src/main/java/com/datastax/oss/protocol/internal/request/Startup.java b/src/main/java/com/datastax/oss/protocol/internal/request/Startup.java index c879838..5e513fa 100644 --- a/src/main/java/com/datastax/oss/protocol/internal/request/Startup.java +++ b/src/main/java/com/datastax/oss/protocol/internal/request/Startup.java @@ -23,24 +23,38 @@ import java.util.Map; public class Startup extends Message { - private static final String CQL_VERSION_KEY = "CQL_VERSION"; - private static final String COMPRESSION_KEY = "COMPRESSION"; + public static final String CQL_VERSION_KEY = "CQL_VERSION"; + public static final String COMPRESSION_KEY = "COMPRESSION"; private static final String CQL_VERSION = "3.0.0"; public final Map options; public Startup(String compressionAlgorithm) { - super(false, ProtocolConstants.Opcode.STARTUP); - this.options = + this( (compressionAlgorithm == null || compressionAlgorithm.isEmpty()) ? NullAllowingImmutableMap.of(CQL_VERSION_KEY, CQL_VERSION) : NullAllowingImmutableMap.of( - CQL_VERSION_KEY, CQL_VERSION, COMPRESSION_KEY, compressionAlgorithm); + CQL_VERSION_KEY, CQL_VERSION, COMPRESSION_KEY, compressionAlgorithm)); } public Startup() { - this(null); + this((Map) null); + } + + public Startup(Map options) { + super(false, ProtocolConstants.Opcode.STARTUP); + if (options != null) { + if (options.containsKey(CQL_VERSION_KEY)) { + this.options = NullAllowingImmutableMap.copyOf(options); + } else { + NullAllowingImmutableMap.Builder builder = + NullAllowingImmutableMap.builder(options.size() + 1); + this.options = builder.put(CQL_VERSION_KEY, CQL_VERSION).putAll(options).build(); + } + } else { + this.options = NullAllowingImmutableMap.of(CQL_VERSION_KEY, CQL_VERSION); + } } @Override @@ -69,7 +83,7 @@ public int encodedSize(Message message) { @Override public Message decode(B source, PrimitiveCodec decoder) { Map map = decoder.readStringMap(source); - return new Startup(map.get(COMPRESSION_KEY)); + return new Startup(map); } } } diff --git a/src/test/java/com/datastax/oss/protocol/internal/request/StartupTest.java b/src/test/java/com/datastax/oss/protocol/internal/request/StartupTest.java index 56d2f47..5889692 100644 --- a/src/test/java/com/datastax/oss/protocol/internal/request/StartupTest.java +++ b/src/test/java/com/datastax/oss/protocol/internal/request/StartupTest.java @@ -22,6 +22,7 @@ import com.datastax.oss.protocol.internal.PrimitiveSizes; import com.datastax.oss.protocol.internal.TestDataProviders; import com.datastax.oss.protocol.internal.binary.MockBinaryString; +import com.datastax.oss.protocol.internal.util.collection.NullAllowingImmutableMap; import com.tngtech.java.junit.dataprovider.DataProviderRunner; import com.tngtech.java.junit.dataprovider.UseDataProvider; import org.junit.Test; @@ -90,4 +91,74 @@ public void should_encode_and_decode_without_compression(int protocolVersion) { assertThat(decoded.options).hasSize(1).containsEntry("CQL_VERSION", "3.0.0"); } + + @Test + @UseDataProvider(location = TestDataProviders.class, value = "protocolV3OrAbove") + public void should_encode_and_decode_with_custom_options(int protocolVersion) { + Startup initial = new Startup(NullAllowingImmutableMap.of("CUSTOM_OPT1", "VALUE1")); + + MockBinaryString encoded = encode(initial, protocolVersion); + + assertThat(encoded) + .isEqualTo( + new MockBinaryString() + .unsignedShort(2) // size of string map + // string map entries + .string("CQL_VERSION") + .string("3.0.0") + .string("CUSTOM_OPT1") + .string("VALUE1")); + assertThat(encodedSize(initial, protocolVersion)) + .isEqualTo( + PrimitiveSizes.SHORT + + (PrimitiveSizes.SHORT + "CQL_VERSION".length()) + + (PrimitiveSizes.SHORT + "3.0.0".length()) + + (PrimitiveSizes.SHORT + "CUSTOM_OPT1".length()) + + (PrimitiveSizes.SHORT + "VALUE1".length())); + + Startup decoded = decode(encoded, protocolVersion); + + assertThat(decoded.options) + .hasSize(2) + .containsEntry("CQL_VERSION", "3.0.0") + .containsEntry("CUSTOM_OPT1", "VALUE1"); + } + + @Test + @UseDataProvider(location = TestDataProviders.class, value = "protocolV3OrAbove") + public void should_encode_and_decode_with_custom_options_and_compression(int protocolVersion) { + Startup initial = + new Startup(NullAllowingImmutableMap.of("CUSTOM_OPT1", "VALUE1", "COMPRESSION", "LZ4")); + + MockBinaryString encoded = encode(initial, protocolVersion); + + assertThat(encoded) + .isEqualTo( + new MockBinaryString() + .unsignedShort(3) // size of string map + // string map entries + .string("CQL_VERSION") + .string("3.0.0") + .string("CUSTOM_OPT1") + .string("VALUE1") + .string("COMPRESSION") + .string("LZ4")); + assertThat(encodedSize(initial, protocolVersion)) + .isEqualTo( + PrimitiveSizes.SHORT + + (PrimitiveSizes.SHORT + "CQL_VERSION".length()) + + (PrimitiveSizes.SHORT + "3.0.0".length()) + + (PrimitiveSizes.SHORT + "CUSTOM_OPT1".length()) + + (PrimitiveSizes.SHORT + "VALUE1".length()) + + (PrimitiveSizes.SHORT + "COMPRESSION".length()) + + (PrimitiveSizes.SHORT + "Lz4".length())); + + Startup decoded = decode(encoded, protocolVersion); + + assertThat(decoded.options) + .hasSize(3) + .containsEntry("CQL_VERSION", "3.0.0") + .containsEntry("CUSTOM_OPT1", "VALUE1") + .containsEntry("COMPRESSION", "LZ4"); + } }