diff --git a/vector/src/main/codegen/templates/AbstractFieldReader.java b/vector/src/main/codegen/templates/AbstractFieldReader.java index c7c5b4d78d..efb87c71dc 100644 --- a/vector/src/main/codegen/templates/AbstractFieldReader.java +++ b/vector/src/main/codegen/templates/AbstractFieldReader.java @@ -109,10 +109,6 @@ public void copyAsField(String name, ${name}Writer writer) { - public void copyAsValue(StructWriter writer, ExtensionTypeWriterFactory writerFactory) { - fail("CopyAsValue StructWriter"); - } - public void read(ExtensionHolder holder) { fail("Extension"); } @@ -144,7 +140,13 @@ public int size() { return -1; } + @Override + public ExtensionTypeWriterFactory getExtensionTypeWriterFactory() { + throw new IllegalStateException("The current reader doesn't support reading extension type"); + } + private void fail(String name) { throw new IllegalArgumentException(String.format("You tried to read a [%s] type when you are using a field reader of type [%s].", name, this.getClass().getSimpleName())); } + } diff --git a/vector/src/main/codegen/templates/AbstractFieldWriter.java b/vector/src/main/codegen/templates/AbstractFieldWriter.java index ae5b97faef..bcdc5c5302 100644 --- a/vector/src/main/codegen/templates/AbstractFieldWriter.java +++ b/vector/src/main/codegen/templates/AbstractFieldWriter.java @@ -107,15 +107,14 @@ public void endEntry() { throw new IllegalStateException(String.format("You tried to end a map entry when you are using a ValueWriter of type %s.", this.getClass().getSimpleName())); } + @Override public void write(ExtensionHolder var1) { this.fail("ExtensionType"); } + @Override public void writeExtension(Object var1) { this.fail("ExtensionType"); } - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory var1) { - this.fail("ExtensionType"); - } <#list vv.types as type><#list type.minor as minor><#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> diff --git a/vector/src/main/codegen/templates/ArrowType.java b/vector/src/main/codegen/templates/ArrowType.java index fd35c1cd2b..5daf57c3d3 100644 --- a/vector/src/main/codegen/templates/ArrowType.java +++ b/vector/src/main/codegen/templates/ArrowType.java @@ -27,8 +27,10 @@ import org.apache.arrow.flatbuf.Type; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.complex.writer.FieldWriter; import org.apache.arrow.vector.types.*; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -331,6 +333,8 @@ public boolean equals(Object obj) { public T accept(ArrowTypeVisitor visitor) { return visitor.visit(this); } + + public abstract FieldWriter getNewFieldWriter(ValueVector vector); } private static final int defaultDecimalBitWidth = 128; diff --git a/vector/src/main/codegen/templates/BaseReader.java b/vector/src/main/codegen/templates/BaseReader.java index 4c6f49ab9b..c52345af21 100644 --- a/vector/src/main/codegen/templates/BaseReader.java +++ b/vector/src/main/codegen/templates/BaseReader.java @@ -49,7 +49,6 @@ public interface RepeatedStructReader extends StructReader{ boolean next(); int size(); void copyAsValue(StructWriter writer); - void copyAsValue(StructWriter writer, ExtensionTypeWriterFactory writerFactory); } public interface ListReader extends BaseReader{ @@ -60,7 +59,6 @@ public interface RepeatedListReader extends ListReader{ boolean next(); int size(); void copyAsValue(ListWriter writer); - void copyAsValue(ListWriter writer, ExtensionTypeWriterFactory writerFactory); } public interface MapReader extends BaseReader{ @@ -71,7 +69,6 @@ public interface RepeatedMapReader extends MapReader{ boolean next(); int size(); void copyAsValue(MapWriter writer); - void copyAsValue(MapWriter writer, ExtensionTypeWriterFactory writerFactory); } public interface ScalarReader extends diff --git a/vector/src/main/codegen/templates/BaseWriter.java b/vector/src/main/codegen/templates/BaseWriter.java index 78da7fddc3..e5333ef292 100644 --- a/vector/src/main/codegen/templates/BaseWriter.java +++ b/vector/src/main/codegen/templates/BaseWriter.java @@ -123,13 +123,6 @@ public interface ExtensionWriter extends BaseWriter { * @param value the extension type value to write */ void writeExtension(Object value); - - /** - * Adds the given extension type factory. This factory allows configuring writer implementations for specific ExtensionTypeVector. - * - * @param factory the extension type factory to add - */ - void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory); } public interface ScalarWriter extends diff --git a/vector/src/main/codegen/templates/ComplexCopier.java b/vector/src/main/codegen/templates/ComplexCopier.java index 4df5478f48..b6449e0e94 100644 --- a/vector/src/main/codegen/templates/ComplexCopier.java +++ b/vector/src/main/codegen/templates/ComplexCopier.java @@ -41,15 +41,8 @@ public class ComplexCopier { * @param input field to read from * @param output field to write to */ - public static void copy(FieldReader input, FieldWriter output) { - writeValue(input, output, null); - } - - public static void copy(FieldReader input, FieldWriter output, ExtensionTypeWriterFactory extensionTypeWriterFactory) { - writeValue(input, output, extensionTypeWriterFactory); - } + public static void copy(FieldReader reader, FieldWriter writer) { - private static void writeValue(FieldReader reader, FieldWriter writer, ExtensionTypeWriterFactory extensionTypeWriterFactory) { final MinorType mt = reader.getMinorType(); switch (mt) { @@ -65,7 +58,7 @@ private static void writeValue(FieldReader reader, FieldWriter writer, Extension FieldReader childReader = reader.reader(); FieldWriter childWriter = getListWriterForReader(childReader, writer); if (childReader.isSet()) { - writeValue(childReader, childWriter, extensionTypeWriterFactory); + copy(childReader, childWriter); } else { childWriter.writeNull(); } @@ -83,8 +76,8 @@ private static void writeValue(FieldReader reader, FieldWriter writer, Extension FieldReader structReader = reader.reader(); if (structReader.isSet()) { writer.startEntry(); - writeValue(mapReader.key(), getMapWriterForReader(mapReader.key(), writer.key()), extensionTypeWriterFactory); - writeValue(mapReader.value(), getMapWriterForReader(mapReader.value(), writer.value()), extensionTypeWriterFactory); + copy(mapReader.key(), getMapWriterForReader(mapReader.key(), writer.key())); + copy(mapReader.value(), getMapWriterForReader(mapReader.value(), writer.value())); writer.endEntry(); } else { writer.writeNull(); @@ -103,7 +96,7 @@ private static void writeValue(FieldReader reader, FieldWriter writer, Extension if (childReader.getMinorType() != Types.MinorType.NULL) { FieldWriter childWriter = getStructWriterForReader(childReader, writer, name); if (childReader.isSet()) { - writeValue(childReader, childWriter, extensionTypeWriterFactory); + copy(childReader, childWriter); } else { childWriter.writeNull(); } @@ -115,13 +108,9 @@ private static void writeValue(FieldReader reader, FieldWriter writer, Extension } break; case EXTENSIONTYPE: - if (extensionTypeWriterFactory == null) { - throw new IllegalArgumentException("Must provide ExtensionTypeWriterFactory"); - } if (reader.isSet()) { Object value = reader.readObject(); if (value != null) { - writer.addExtensionTypeWriterFactory(extensionTypeWriterFactory); writer.writeExtension(value); } } else { diff --git a/vector/src/main/codegen/templates/PromotableWriter.java b/vector/src/main/codegen/templates/PromotableWriter.java index d22eb00b2c..83e6b39c55 100644 --- a/vector/src/main/codegen/templates/PromotableWriter.java +++ b/vector/src/main/codegen/templates/PromotableWriter.java @@ -286,7 +286,7 @@ protected void setWriter(ValueVector v) { writer = new UnionWriter((UnionVector) vector, nullableStructWriterFactory); break; case EXTENSIONTYPE: - writer = new UnionExtensionWriter((ExtensionTypeVector) vector); + writer = ((ExtensionType) vector.getField().getType()).getNewFieldWriter(vector); break; default: writer = type.getNewFieldWriter(vector); @@ -325,6 +325,9 @@ protected boolean requiresArrowType(MinorType type) { @Override protected FieldWriter getWriter(MinorType type, ArrowType arrowType) { + if(type == MinorType.EXTENSIONTYPE) { + lastExtensionType = arrowType; + } if (state == State.UNION) { if (requiresArrowType(type)) { ((UnionWriter) writer).getWriter(type, arrowType); @@ -540,18 +543,25 @@ public void writeLargeVarChar(String value) { getWriter(MinorType.LARGEVARCHAR).writeLargeVarChar(value); } + protected ArrowType lastExtensionType; + @Override public void writeExtension(Object value) { - getWriter(MinorType.EXTENSIONTYPE).writeExtension(value); + FieldWriter writer = getWriter(MinorType.EXTENSIONTYPE, lastExtensionType); + if(writer instanceof UnionWriter) { + ((UnionWriter) writer).writeExtension(value, lastExtensionType); + } else { + writer.writeExtension(value); + } } - @Override - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory) { - getWriter(MinorType.EXTENSIONTYPE).addExtensionTypeWriterFactory(factory); + public void writeExtension(Object value, ArrowType arrowType) { + getWriter(MinorType.EXTENSIONTYPE, arrowType).writeExtension(value); } - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory, ArrowType arrowType) { - getWriter(MinorType.EXTENSIONTYPE, arrowType).addExtensionTypeWriterFactory(factory); + @Override + public void write(ExtensionHolder holder) { + getWriter(MinorType.EXTENSIONTYPE, lastExtensionType).write(holder); } @Override diff --git a/vector/src/main/codegen/templates/UnionListWriter.java b/vector/src/main/codegen/templates/UnionListWriter.java index 3c41ac72b6..a01d6ece30 100644 --- a/vector/src/main/codegen/templates/UnionListWriter.java +++ b/vector/src/main/codegen/templates/UnionListWriter.java @@ -204,13 +204,13 @@ public MapWriter map(String name, boolean keysSorted) { @Override public ExtensionWriter extension(ArrowType arrowType) { - this.extensionType = arrowType; + extensionType = arrowType; return this; } + @Override public ExtensionWriter extension(String name, ArrowType arrowType) { - ExtensionWriter extensionWriter = writer.extension(name, arrowType); - return extensionWriter; + return writer.extension(name, arrowType); } <#if listName == "LargeList"> @@ -337,15 +337,10 @@ public void writeNull() { @Override public void writeExtension(Object value) { - writer.writeExtension(value); + writer.writeExtension(value, extensionType); writer.setPosition(writer.idx() + 1); } - @Override - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory var1) { - writer.addExtensionTypeWriterFactory(var1, extensionType); - } - public void write(ExtensionHolder var1) { writer.write(var1); writer.setPosition(writer.idx() + 1); diff --git a/vector/src/main/codegen/templates/UnionReader.java b/vector/src/main/codegen/templates/UnionReader.java index 96ad3e1b9b..dcfe991ad6 100644 --- a/vector/src/main/codegen/templates/UnionReader.java +++ b/vector/src/main/codegen/templates/UnionReader.java @@ -79,6 +79,10 @@ public void read(int index, UnionHolder holder) { } private FieldReader getReaderForIndex(int index) { + return getReaderForIndex(index, null); + } + + private FieldReader getReaderForIndex(int index, ArrowType type) { int typeValue = data.getTypeValue(index); FieldReader reader = (FieldReader) readers[typeValue]; if (reader != null) { @@ -105,11 +109,26 @@ private FieldReader getReaderForIndex(int index) { + case EXTENSIONTYPE: + if(type == null) { + throw new RuntimeError("Cannot get Extension reader without an ArrowType"); + } + return (FieldReader) getExtension(type); default: throw new UnsupportedOperationException("Unsupported type: " + MinorType.values()[typeValue]); } } + private ExtensionReader extensionReader; + + private ExtensionReader getExtension(ArrowType type) { + if (extensionReader == null) { + extensionReader = data.getExtension(type).getReader(); + extensionReader.setPosition(idx()); + } + return extensionReader; + } + private SingleStructReaderImpl structReader; private StructReader getStruct() { @@ -240,4 +259,8 @@ public FieldReader reader() { public boolean next() { return getReaderForIndex(idx()).next(); } + + public void read(ExtensionHolder holder){ + getReaderForIndex(idx(), holder.type()).read(holder); + } } diff --git a/vector/src/main/codegen/templates/UnionVector.java b/vector/src/main/codegen/templates/UnionVector.java index 67efdf60f7..c706591966 100644 --- a/vector/src/main/codegen/templates/UnionVector.java +++ b/vector/src/main/codegen/templates/UnionVector.java @@ -379,6 +379,22 @@ public MapVector getMap(String name, ArrowType arrowType) { return mapVector; } + private ExtensionTypeVector extensionVector; + + public ExtensionTypeVector getExtension(ArrowType arrowType) { + if (extensionVector == null) { + int vectorCount = internalStruct.size(); + extensionVector = addOrGet(null, MinorType.EXTENSIONTYPE, arrowType, ExtensionTypeVector.class); + if (internalStruct.size() > vectorCount) { + extensionVector.allocateNew(); + if (callBack != null) { + callBack.doWork(); + } + } + } + return extensionVector; + } + public int getTypeValue(int index) { return typeBuffer.getByte(index * TYPE_WIDTH); } @@ -725,6 +741,8 @@ public ValueVector getVectorByType(int typeId, ArrowType arrowType) { return getListView(); case MAP: return getMap(name, arrowType); + case EXTENSIONTYPE: + return getExtension(arrowType); default: throw new UnsupportedOperationException("Cannot support type: " + MinorType.values()[typeId]); } diff --git a/vector/src/main/codegen/templates/UnionWriter.java b/vector/src/main/codegen/templates/UnionWriter.java index 272edab17c..0db699fd8c 100644 --- a/vector/src/main/codegen/templates/UnionWriter.java +++ b/vector/src/main/codegen/templates/UnionWriter.java @@ -28,6 +28,8 @@ package org.apache.arrow.vector.complex.impl; <#include "/@includes/vv_imports.ftl" /> +import java.util.HashMap; + import org.apache.arrow.vector.complex.writer.BaseWriter; import org.apache.arrow.vector.types.Types.MinorType; @@ -213,8 +215,31 @@ public MapWriter asMap(ArrowType arrowType) { return getMapWriter(arrowType); } + private java.util.Map extensionWriters = new HashMap<>(); + private ExtensionWriter getExtensionWriter(ArrowType arrowType) { - throw new UnsupportedOperationException("ExtensionTypes are not supported yet."); + ExtensionWriter w = extensionWriters.get(arrowType); + if (w == null) { + w = ((ExtensionType) arrowType).getNewFieldWriter(data.getExtension(arrowType)); + w.setPosition(idx()); + extensionWriters.put(arrowType, w); + } + return w; + } + + public void writeExtension(Object value, ArrowType type) { + data.setType(idx(), MinorType.EXTENSIONTYPE); + ExtensionWriter w = getExtensionWriter(type); + w.setPosition(idx()); + w.writeExtension(value); + } + + @Override + public void write(ExtensionHolder holder) { + data.setType(idx(), MinorType.EXTENSIONTYPE); + ExtensionWriter w = getExtensionWriter(holder.type()); + w.setPosition(idx()); + w.write(holder); } BaseWriter getWriter(MinorType minorType) { diff --git a/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java b/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java index cc57cde29e..37dfa20616 100644 --- a/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java @@ -22,7 +22,6 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.ReferenceManager; import org.apache.arrow.util.Preconditions; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.util.DataSizeRoundingUtil; import org.apache.arrow.vector.util.TransferPair; @@ -261,18 +260,6 @@ public void copyFromSafe(int fromIndex, int thisIndex, ValueVector from) { throw new UnsupportedOperationException(); } - @Override - public void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - - @Override - public void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - /** * Transfer the validity buffer from `validityBuffer` to the target vector's `validityBuffer`. * Start at `startIndex` and copy `length` number of elements. If the starting index is 8 byte diff --git a/vector/src/main/java/org/apache/arrow/vector/NullVector.java b/vector/src/main/java/org/apache/arrow/vector/NullVector.java index 0d6dab2837..6bfe540d23 100644 --- a/vector/src/main/java/org/apache/arrow/vector/NullVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/NullVector.java @@ -27,7 +27,6 @@ import org.apache.arrow.memory.util.hash.ArrowBufHasher; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.compare.VectorVisitor; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.impl.NullReader; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; @@ -330,18 +329,6 @@ public void copyFromSafe(int fromIndex, int thisIndex, ValueVector from) { throw new UnsupportedOperationException(); } - @Override - public void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - - @Override - public void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - @Override public String getName() { return this.getField().getName(); diff --git a/vector/src/main/java/org/apache/arrow/vector/ValueVector.java b/vector/src/main/java/org/apache/arrow/vector/ValueVector.java index e0628c2ee1..3a5058256c 100644 --- a/vector/src/main/java/org/apache/arrow/vector/ValueVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/ValueVector.java @@ -22,7 +22,6 @@ import org.apache.arrow.memory.OutOfMemoryException; import org.apache.arrow.memory.util.hash.ArrowBufHasher; import org.apache.arrow.vector.compare.VectorVisitor; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; @@ -310,30 +309,6 @@ public interface ValueVector extends Closeable, Iterable { */ void copyFromSafe(int fromIndex, int thisIndex, ValueVector from); - /** - * Copy a cell value from a particular index in source vector to a particular position in this - * vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory); - - /** - * Same as {@link #copyFrom(int, int, ValueVector)} except that it handles the case when the - * capacity of the vector needs to be expanded before copy. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory); - /** * Accept a generic {@link VectorVisitor} and return the result. * diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java index 429f9884bb..a6a71cf1a4 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java @@ -21,7 +21,6 @@ import org.apache.arrow.vector.DensityAwareVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.FixedSizeList; @@ -152,18 +151,6 @@ public void copyFromSafe(int fromIndex, int thisIndex, ValueVector from) { throw new UnsupportedOperationException(); } - @Override - public void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - - @Override - public void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - @Override public String getName() { return name; diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java index 48c8127e23..997b5a8b78 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java @@ -49,7 +49,6 @@ import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.impl.ComplexCopier; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.impl.UnionLargeListReader; import org.apache.arrow.vector.complex.impl.UnionLargeListWriter; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -483,42 +482,12 @@ public void copyFromSafe(int inIndex, int outIndex, ValueVector from) { */ @Override public void copyFrom(int inIndex, int outIndex, ValueVector from) { - copyFrom(inIndex, outIndex, from, null); - } - - /** - * Copy a cell value from a particular index in source vector to a particular position in this - * vector. - * - * @param inIndex position to copy from in source vector - * @param outIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - @Override - public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { Preconditions.checkArgument(this.getMinorType() == from.getMinorType()); FieldReader in = from.getReader(); in.setPosition(inIndex); UnionLargeListWriter out = getWriter(); out.setPosition(outIndex); - ComplexCopier.copy(in, out, writerFactory); - } - - /** - * Same as {@link #copyFrom(int, int, ValueVector)} except that it handles the case when the - * capacity of the vector needs to be expanded before copy. - * - * @param inIndex position to copy from in source vector - * @param outIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - @Override - public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - copyFrom(inIndex, outIndex, from, writerFactory); + ComplexCopier.copy(in, out); } /** diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java index 992a664449..2da7eb057e 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java @@ -41,7 +41,6 @@ import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.impl.UnionLargeListViewReader; import org.apache.arrow.vector.complex.impl.UnionLargeListViewWriter; import org.apache.arrow.vector.complex.impl.UnionListReader; @@ -347,20 +346,6 @@ public void copyFrom(int inIndex, int outIndex, ValueVector from) { "LargeListViewVector does not support copyFrom operation yet."); } - @Override - public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException( - "LargeListViewVector does not support copyFromSafe operation yet."); - } - - @Override - public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException( - "LargeListViewVector does not support copyFrom operation yet."); - } - @Override public FieldVector getDataVector() { return vector; diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index 89549257c4..93a313ef4f 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -42,7 +42,6 @@ import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.impl.ComplexCopier; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.impl.UnionListReader; import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -401,42 +400,12 @@ public void copyFromSafe(int inIndex, int outIndex, ValueVector from) { */ @Override public void copyFrom(int inIndex, int outIndex, ValueVector from) { - copyFrom(inIndex, outIndex, from, null); - } - - /** - * Same as {@link #copyFrom(int, int, ValueVector)} except that it handles the case when the - * capacity of the vector needs to be expanded before copy. - * - * @param inIndex position to copy from in source vector - * @param outIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - @Override - public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - copyFrom(inIndex, outIndex, from, writerFactory); - } - - /** - * Copy a cell value from a particular index in source vector to a particular position in this - * vector. - * - * @param inIndex position to copy from in source vector - * @param outIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - @Override - public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { Preconditions.checkArgument(this.getMinorType() == from.getMinorType()); FieldReader in = from.getReader(); in.setPosition(inIndex); FieldWriter out = getWriter(); out.setPosition(outIndex); - ComplexCopier.copy(in, out, writerFactory); + ComplexCopier.copy(in, out); } /** diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java index 2784240429..8711db5e0f 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java @@ -42,7 +42,6 @@ import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.impl.ComplexCopier; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.impl.UnionListViewReader; import org.apache.arrow.vector.complex.impl.UnionListViewWriter; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -339,12 +338,6 @@ public void copyFromSafe(int inIndex, int outIndex, ValueVector from) { copyFrom(inIndex, outIndex, from); } - @Override - public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - copyFrom(inIndex, outIndex, from, writerFactory); - } - @Override public OUT accept(VectorVisitor visitor, IN value) { return visitor.visit(this, value); @@ -352,18 +345,12 @@ public OUT accept(VectorVisitor visitor, IN value) { @Override public void copyFrom(int inIndex, int outIndex, ValueVector from) { - copyFrom(inIndex, outIndex, from, null); - } - - @Override - public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { Preconditions.checkArgument(this.getMinorType() == from.getMinorType()); FieldReader in = from.getReader(); in.setPosition(inIndex); FieldWriter out = getWriter(); out.setPosition(outIndex); - ComplexCopier.copy(in, out, writerFactory); + ComplexCopier.copy(in, out); } @Override diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java index bf074ecb90..64c1f836c2 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java @@ -117,12 +117,7 @@ public void copyAsValue(MapWriter writer) { } @Override - public void copyAsValue(ListWriter writer, ExtensionTypeWriterFactory writerFactory) { - ComplexCopier.copy(this, (FieldWriter) writer, writerFactory); - } - - @Override - public void copyAsValue(MapWriter writer, ExtensionTypeWriterFactory writerFactory) { - ComplexCopier.copy(this, (FieldWriter) writer, writerFactory); + public ExtensionTypeWriterFactory getExtensionTypeWriterFactory() { + throw new IllegalStateException("The current reader doesn't support reading extension type"); } } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java index 4219069cba..52e63306d8 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java @@ -59,12 +59,6 @@ public void writeExtension(Object var1) { this.writer.writeExtension(var1); } - @Override - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory) { - this.writer = factory.getWriterImpl(vector); - this.writer.setPosition(idx()); - } - public void write(ExtensionHolder holder) { this.writer.write(holder); } @@ -79,6 +73,7 @@ public void setPosition(int index) { @Override public void writeNull() { - this.writer.writeNull(); + this.vector.setNull(getPosition()); + this.vector.setValueCount(getPosition() + 1); } } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java index a9104cb0d2..be236c3166 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java @@ -105,8 +105,4 @@ public boolean next() { public void copyAsValue(UnionLargeListWriter writer) { ComplexCopier.copy(this, (FieldWriter) writer); } - - public void copyAsValue(UnionLargeListWriter writer, ExtensionTypeWriterFactory writerFactory) { - ComplexCopier.copy(this, (FieldWriter) writer, writerFactory); - } } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/reader/ExtensionReader.java b/vector/src/main/java/org/apache/arrow/vector/complex/reader/ExtensionReader.java index 1ba7b27156..406e080d1d 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/reader/ExtensionReader.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/reader/ExtensionReader.java @@ -16,6 +16,7 @@ */ package org.apache.arrow.vector.complex.reader; +import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.holders.ExtensionHolder; /** Interface for reading extension types. Extends the functionality of {@link BaseReader}. */ @@ -41,4 +42,15 @@ public interface ExtensionReader extends BaseReader { * @return true if the value is set, false otherwise */ boolean isSet(); + + /** + * Gets the extension type writer factory associated with this reader. + * + *

The writer factory is used to create appropriate writers when copying extension type values + * to another vector. This allows the copy operation to preserve the extension type semantics. + * + * @return the extension type writer factory + * @throws IllegalStateException if the reader doesn't support extension types + */ + ExtensionTypeWriterFactory getExtensionTypeWriterFactory(); } diff --git a/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueType.java b/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueType.java index ca56214fda..780a4ee659 100644 --- a/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueType.java +++ b/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueType.java @@ -54,10 +54,12 @@ import org.apache.arrow.vector.TimeStampNanoVector; import org.apache.arrow.vector.TimeStampSecTZVector; import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.ViewVarBinaryVector; import org.apache.arrow.vector.ViewVarCharVector; +import org.apache.arrow.vector.complex.writer.FieldWriter; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ExtensionTypeRegistry; @@ -177,6 +179,11 @@ public int hashCode() { return Objects.hash(super.hashCode(), storageType, typeName, vendorName); } + @Override + public FieldWriter getNewFieldWriter(ValueVector vector) { + throw new UnsupportedOperationException("WriterImpl not yet implemented."); + } + @Override public String toString() { return "OpaqueType(" diff --git a/vector/src/main/java/org/apache/arrow/vector/holders/ExtensionHolder.java b/vector/src/main/java/org/apache/arrow/vector/holders/ExtensionHolder.java index fc7ed85878..4d3f767aef 100644 --- a/vector/src/main/java/org/apache/arrow/vector/holders/ExtensionHolder.java +++ b/vector/src/main/java/org/apache/arrow/vector/holders/ExtensionHolder.java @@ -16,7 +16,11 @@ */ package org.apache.arrow.vector.holders; +import org.apache.arrow.vector.types.pojo.ArrowType; + /** Base {@link ValueHolder} class for a {@link org.apache.arrow.vector.ExtensionTypeVector}. */ public abstract class ExtensionHolder implements ValueHolder { public int isSet; + + public abstract ArrowType type(); } diff --git a/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java b/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java index d5cbf925b2..0eb8f5c066 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java @@ -23,20 +23,26 @@ import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.UUID; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.complex.BaseRepeatedValueVector; import org.apache.arrow.vector.complex.LargeListVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionLargeListReader; import org.apache.arrow.vector.complex.impl.UnionLargeListWriter; import org.apache.arrow.vector.complex.reader.FieldReader; +import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter; +import org.apache.arrow.vector.holder.UuidHolder; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.UuidType; import org.apache.arrow.vector.util.TransferPair; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -1021,6 +1027,83 @@ public void testGetTransferPairWithField() throws Exception { } } + @Test + public void testCopyValueSafeForExtensionType() throws Exception { + try (LargeListVector inVector = LargeListVector.empty("input", allocator); + LargeListVector outVector = LargeListVector.empty("output", allocator)) { + UnionLargeListWriter writer = inVector.getWriter(); + writer.allocate(); + + // Create first list with UUIDs + writer.setPosition(0); + UUID u1 = UUID.randomUUID(); + UUID u2 = UUID.randomUUID(); + writer.startList(); + ExtensionWriter extensionWriter = writer.extension(UuidType.INSTANCE); + extensionWriter.writeExtension(u1); + extensionWriter.writeExtension(u2); + writer.endList(); + + // Create second list with UUIDs + writer.setPosition(1); + UUID u3 = UUID.randomUUID(); + UUID u4 = UUID.randomUUID(); + writer.startList(); + extensionWriter = writer.extension(UuidType.INSTANCE); + extensionWriter.writeExtension(u3); + extensionWriter.writeExtension(u4); + extensionWriter.writeNull(); + + writer.endList(); + writer.setValueCount(2); + + // Use copyFromSafe with ExtensionTypeWriterFactory + // This internally calls TransferImpl.copyValueSafe with ExtensionTypeWriterFactory + outVector.allocateNew(); + TransferPair tp = inVector.makeTransferPair(outVector); + tp.copyValueSafe(0, 0); + tp.copyValueSafe(1, 1); + outVector.setValueCount(2); + + // Verify first list + UnionLargeListReader reader = outVector.getReader(); + reader.setPosition(0); + assertTrue(reader.isSet(), "first list shouldn't be null"); + reader.next(); + FieldReader uuidReader = reader.reader(); + UuidHolder holder = new UuidHolder(); + uuidReader.read(holder); + ByteBuffer bb = ByteBuffer.wrap(holder.value); + UUID actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(u1, actualUuid); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + bb = ByteBuffer.wrap(holder.value); + actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(u2, actualUuid); + + // Verify second list + reader.setPosition(1); + assertTrue(reader.isSet(), "second list shouldn't be null"); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + bb = ByteBuffer.wrap(holder.value); + actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(u3, actualUuid); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + bb = ByteBuffer.wrap(holder.value); + actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(u4, actualUuid); + reader.next(); + uuidReader = reader.reader(); + assertFalse(uuidReader.isSet(), "third element should be null"); + } + } + private void writeIntValues(UnionLargeListWriter writer, int[] values) { writer.startList(); for (int v : values) { diff --git a/vector/src/test/java/org/apache/arrow/vector/TestListVector.java b/vector/src/test/java/org/apache/arrow/vector/TestListVector.java index c6c7c5c862..c206ed2261 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestListVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestListVector.java @@ -36,7 +36,6 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.impl.UnionListReader; import org.apache.arrow.vector.complex.impl.UnionListWriter; -import org.apache.arrow.vector.complex.impl.UuidWriterFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter; import org.apache.arrow.vector.holder.UuidHolder; @@ -1216,8 +1215,7 @@ public void testListVectorWithExtensionType() throws Exception { UUID u1 = UUID.randomUUID(); UUID u2 = UUID.randomUUID(); writer.startList(); - ExtensionWriter extensionWriter = writer.extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + ExtensionWriter extensionWriter = writer.extension(UuidType.INSTANCE); extensionWriter.writeExtension(u1); extensionWriter.writeExtension(u2); writer.endList(); @@ -1244,8 +1242,7 @@ public void testListVectorReaderForExtensionType() throws Exception { UUID u1 = UUID.randomUUID(); UUID u2 = UUID.randomUUID(); writer.startList(); - ExtensionWriter extensionWriter = writer.extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + ExtensionWriter extensionWriter = writer.extension(UuidType.INSTANCE); extensionWriter.writeExtension(u1); extensionWriter.writeExtension(u2); writer.endList(); @@ -1281,19 +1278,18 @@ public void testCopyFromForExtensionType() throws Exception { UUID u1 = UUID.randomUUID(); UUID u2 = UUID.randomUUID(); writer.startList(); - ExtensionWriter extensionWriter = writer.extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u1); - extensionWriter.writeExtension(u2); - extensionWriter.writeNull(); + + writer.extension(UuidType.INSTANCE).writeExtension(u1); + writer.writeExtension(u2); + writer.writeNull(); writer.endList(); - writer.setValueCount(1); + writer.setValueCount(3); // copy values from input to output outVector.allocateNew(); - outVector.copyFrom(0, 0, inVector, new UuidWriterFactory()); - outVector.setValueCount(1); + outVector.copyFrom(0, 0, inVector); + outVector.setValueCount(3); UnionListReader reader = outVector.getReader(); assertTrue(reader.isSet(), "shouldn't be null"); @@ -1314,6 +1310,83 @@ public void testCopyFromForExtensionType() throws Exception { } } + @Test + public void testCopyValueSafeForExtensionType() throws Exception { + try (ListVector inVector = ListVector.empty("input", allocator); + ListVector outVector = ListVector.empty("output", allocator)) { + UnionListWriter writer = inVector.getWriter(); + writer.allocate(); + + // Create first list with UUIDs + writer.setPosition(0); + UUID u1 = UUID.randomUUID(); + UUID u2 = UUID.randomUUID(); + writer.startList(); + ExtensionWriter extensionWriter = writer.extension(UuidType.INSTANCE); + extensionWriter.writeExtension(u1); + extensionWriter.writeExtension(u2); + writer.endList(); + + // Create second list with UUIDs + writer.setPosition(1); + UUID u3 = UUID.randomUUID(); + UUID u4 = UUID.randomUUID(); + writer.startList(); + extensionWriter = writer.extension(UuidType.INSTANCE); + extensionWriter.writeExtension(u3); + extensionWriter.writeExtension(u4); + extensionWriter.writeNull(); + + writer.endList(); + writer.setValueCount(2); + + // Use TransferPair with ExtensionTypeWriterFactory + // This tests the new makeTransferPair API with writerFactory parameter + outVector.allocateNew(); + TransferPair transferPair = inVector.makeTransferPair(outVector); + transferPair.copyValueSafe(0, 0); + transferPair.copyValueSafe(1, 1); + outVector.setValueCount(2); + + // Verify first list + UnionListReader reader = outVector.getReader(); + reader.setPosition(0); + assertTrue(reader.isSet(), "first list shouldn't be null"); + reader.next(); + FieldReader uuidReader = reader.reader(); + UuidHolder holder = new UuidHolder(); + uuidReader.read(holder); + ByteBuffer bb = ByteBuffer.wrap(holder.value); + UUID actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(u1, actualUuid); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + bb = ByteBuffer.wrap(holder.value); + actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(u2, actualUuid); + + // Verify second list + reader.setPosition(1); + assertTrue(reader.isSet(), "second list shouldn't be null"); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + bb = ByteBuffer.wrap(holder.value); + actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(u3, actualUuid); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + bb = ByteBuffer.wrap(holder.value); + actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(u4, actualUuid); + reader.next(); + uuidReader = reader.reader(); + assertFalse(uuidReader.isSet(), "third element should be null"); + } + } + private void writeIntValues(UnionListWriter writer, int[] values) { writer.startList(); for (int v : values) { diff --git a/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java b/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java index 1a1810d0f7..125a243541 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java @@ -34,7 +34,6 @@ import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.impl.UnionMapReader; import org.apache.arrow.vector.complex.impl.UnionMapWriter; -import org.apache.arrow.vector.complex.impl.UuidWriterFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.ListWriter; @@ -1281,14 +1280,12 @@ public void testMapVectorWithExtensionType() throws Exception { writer.startMap(); writer.startEntry(); writer.key().bigInt().writeBigInt(0); - ExtensionWriter extensionWriter = writer.value().extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + ExtensionWriter extensionWriter = writer.value().extension(UuidType.INSTANCE); extensionWriter.writeExtension(u1); writer.endEntry(); writer.startEntry(); writer.key().bigInt().writeBigInt(1); - extensionWriter = writer.value().extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + extensionWriter = writer.value().extension(UuidType.INSTANCE); extensionWriter.writeExtension(u2); writer.endEntry(); writer.endMap(); @@ -1325,21 +1322,19 @@ public void testCopyFromForExtensionType() throws Exception { writer.startMap(); writer.startEntry(); writer.key().bigInt().writeBigInt(0); - ExtensionWriter extensionWriter = writer.value().extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + ExtensionWriter extensionWriter = writer.value().extension(UuidType.INSTANCE); extensionWriter.writeExtension(u1); writer.endEntry(); writer.startEntry(); writer.key().bigInt().writeBigInt(1); - extensionWriter = writer.value().extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + extensionWriter = writer.value().extension(UuidType.INSTANCE); extensionWriter.writeExtension(u2); writer.endEntry(); writer.endMap(); writer.setValueCount(1); outVector.allocateNew(); - outVector.copyFrom(0, 0, inVector, new UuidWriterFactory()); + outVector.copyFrom(0, 0, inVector); outVector.setValueCount(1); UnionMapReader mapReader = outVector.getReader(); diff --git a/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java b/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java index d40af9ae89..307c636270 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java @@ -160,17 +160,23 @@ public void testGetPrimitiveVectors() { UnionVector unionVector = vector.addOrGetUnion("union"); unionVector.addVector(new BigIntVector("bigInt", allocator)); unionVector.addVector(new SmallIntVector("smallInt", allocator)); + unionVector.addVector(new UuidVector("uuid", allocator)); // add varchar vector vector.addOrGet( "varchar", FieldType.nullable(MinorType.VARCHAR.getType()), VarCharVector.class); + // add extension vector + vector.addOrGet("extension", FieldType.nullable(UuidType.INSTANCE), UuidVector.class); + List primitiveVectors = vector.getPrimitiveVectors(); - assertEquals(4, primitiveVectors.size()); + assertEquals(6, primitiveVectors.size()); assertEquals(MinorType.INT, primitiveVectors.get(0).getMinorType()); assertEquals(MinorType.BIGINT, primitiveVectors.get(1).getMinorType()); assertEquals(MinorType.SMALLINT, primitiveVectors.get(2).getMinorType()); - assertEquals(MinorType.VARCHAR, primitiveVectors.get(3).getMinorType()); + assertEquals(MinorType.EXTENSIONTYPE, primitiveVectors.get(3).getMinorType()); + assertEquals(MinorType.VARCHAR, primitiveVectors.get(4).getMinorType()); + assertEquals(MinorType.EXTENSIONTYPE, primitiveVectors.get(5).getMinorType()); } } diff --git a/vector/src/test/java/org/apache/arrow/vector/UuidVector.java b/vector/src/test/java/org/apache/arrow/vector/UuidVector.java index 72ba4aa555..d64be54c36 100644 --- a/vector/src/test/java/org/apache/arrow/vector/UuidVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/UuidVector.java @@ -22,6 +22,7 @@ import org.apache.arrow.memory.util.hash.ArrowBufHasher; import org.apache.arrow.vector.complex.impl.UuidReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; +import org.apache.arrow.vector.holder.NullableUuidHolder; import org.apache.arrow.vector.holder.UuidHolder; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; @@ -97,6 +98,11 @@ public void get(int index, UuidHolder holder) { holder.isSet = 1; } + public void get(int index, NullableUuidHolder holder) { + holder.value = getUnderlyingVector().get(index); + holder.isSet = 1; + } + public class TransferImpl implements TransferPair { UuidVector to; ValueVector targetUnderlyingVector; diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java index 738e8905e3..cecb42b92c 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java @@ -860,8 +860,7 @@ public void testCopyListVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { listWriter.setPosition(i); listWriter.startList(); - ExtensionWriter extensionWriter = listWriter.extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + ExtensionWriter extensionWriter = listWriter.extension(UuidType.INSTANCE); extensionWriter.writeExtension(UUID.randomUUID()); extensionWriter.writeExtension(UUID.randomUUID()); listWriter.endList(); @@ -874,7 +873,7 @@ public void testCopyListVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { in.setPosition(i); out.setPosition(i); - ComplexCopier.copy(in, out, new UuidWriterFactory()); + ComplexCopier.copy(in, out); } to.setValueCount(COUNT); @@ -896,11 +895,9 @@ public void testCopyMapVectorWithExtensionType() { mapWriter.setPosition(i); mapWriter.startMap(); mapWriter.startEntry(); - ExtensionWriter extensionKeyWriter = mapWriter.key().extension(new UuidType()); - extensionKeyWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + ExtensionWriter extensionKeyWriter = mapWriter.key().extension(UuidType.INSTANCE); extensionKeyWriter.writeExtension(UUID.randomUUID()); - ExtensionWriter extensionValueWriter = mapWriter.value().extension(new UuidType()); - extensionValueWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + ExtensionWriter extensionValueWriter = mapWriter.value().extension(UuidType.INSTANCE); extensionValueWriter.writeExtension(UUID.randomUUID()); mapWriter.endEntry(); mapWriter.endMap(); @@ -914,7 +911,7 @@ public void testCopyMapVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { in.setPosition(i); out.setPosition(i); - ComplexCopier.copy(in, out, new UuidWriterFactory()); + ComplexCopier.copy(in, out); } to.setValueCount(COUNT); @@ -934,11 +931,9 @@ public void testCopyStructVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { structWriter.setPosition(i); structWriter.start(); - ExtensionWriter extensionWriter1 = structWriter.extension("timestamp1", new UuidType()); - extensionWriter1.addExtensionTypeWriterFactory(new UuidWriterFactory()); + ExtensionWriter extensionWriter1 = structWriter.extension("uuid1", UuidType.INSTANCE); extensionWriter1.writeExtension(UUID.randomUUID()); - ExtensionWriter extensionWriter2 = structWriter.extension("timestamp2", new UuidType()); - extensionWriter2.addExtensionTypeWriterFactory(new UuidWriterFactory()); + ExtensionWriter extensionWriter2 = structWriter.extension("uuid2", UuidType.INSTANCE); extensionWriter2.writeExtension(UUID.randomUUID()); structWriter.end(); } @@ -951,7 +946,7 @@ public void testCopyStructVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { in.setPosition(i); out.setPosition(i); - ComplexCopier.copy(in, out, new UuidWriterFactory()); + ComplexCopier.copy(in, out); } to.setValueCount(COUNT); diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index 7b8b1f9ef9..3a5f6b2954 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -31,6 +31,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DirtyRootAllocator; +import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.LargeVarBinaryVector; import org.apache.arrow.vector.LargeVarCharVector; import org.apache.arrow.vector.UuidVector; @@ -41,6 +42,7 @@ import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter; +import org.apache.arrow.vector.holder.UuidHolder; import org.apache.arrow.vector.holders.DurationHolder; import org.apache.arrow.vector.holders.FixedSizeBinaryHolder; import org.apache.arrow.vector.holders.NullableDecimalHolder; @@ -100,7 +102,6 @@ public void testPromoteToUnion() throws Exception { writer.integer("A").writeInt(10); // we don't write anything in 3 - writer.setPosition(4); writer.integer("A").writeInt(100); @@ -130,9 +131,23 @@ public void testPromoteToUnion() throws Exception { binHolder.buffer = buf; writer.fixedSizeBinary("A", 4).write(binHolder); + writer.setPosition(9); + UUID uuid = UUID.randomUUID(); + writer.extension("A", UuidType.INSTANCE).writeExtension(uuid); + writer.end(); + + writer.setPosition(10); + UUID uuid2 = UUID.randomUUID(); + UuidHolder uuidHolder = new UuidHolder(); + uuidHolder.value = + ByteBuffer.allocate(16) + .putLong(uuid2.getMostSignificantBits()) + .putLong(uuid2.getLeastSignificantBits()) + .array(); + writer.extension("A", UuidType.INSTANCE).write(uuidHolder); writer.end(); - container.setValueCount(9); + container.setValueCount(11); final UnionVector uv = v.getChild("A", UnionVector.class); @@ -169,6 +184,12 @@ public void testPromoteToUnion() throws Exception { .order(ByteOrder.nativeOrder()) .getInt()); + assertFalse(uv.isNull(9), "9 shouldn't be null"); + assertEquals(uuid, uv.getObject(9)); + + assertFalse(uv.isNull(10), "10 shouldn't be null"); + assertEquals(uuid2, uv.getObject(10)); + container.clear(); container.allocateNew(); @@ -785,13 +806,12 @@ public void testExtensionType() throws Exception { try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); final UuidVector v = - container.addOrGet("uuid", FieldType.nullable(new UuidType()), UuidVector.class); + container.addOrGet("uuid", FieldType.nullable(UuidType.INSTANCE), UuidVector.class); final PromotableWriter writer = new PromotableWriter(v, container)) { UUID u1 = UUID.randomUUID(); UUID u2 = UUID.randomUUID(); container.allocateNew(); container.setValueCount(1); - writer.addExtensionTypeWriterFactory(new UuidWriterFactory()); writer.setPosition(0); writer.writeExtension(u1); @@ -810,13 +830,13 @@ public void testExtensionType() throws Exception { public void testExtensionTypeForList() throws Exception { try (final ListVector container = ListVector.empty(EMPTY_SCHEMA_PATH, allocator); final UuidVector v = - (UuidVector) container.addOrGetVector(FieldType.nullable(new UuidType())).getVector(); + (UuidVector) + container.addOrGetVector(FieldType.nullable(UuidType.INSTANCE)).getVector(); final PromotableWriter writer = new PromotableWriter(v, container)) { UUID u1 = UUID.randomUUID(); UUID u2 = UUID.randomUUID(); container.allocateNew(); container.setValueCount(1); - writer.addExtensionTypeWriterFactory(new UuidWriterFactory()); writer.setPosition(0); writer.writeExtension(u1); @@ -825,7 +845,7 @@ public void testExtensionTypeForList() throws Exception { container.setValueCount(2); - UuidVector uuidVector = (UuidVector) container.getDataVector(); + FieldVector uuidVector = container.getDataVector(); assertEquals(u1, uuidVector.getObject(0)); assertEquals(u2, uuidVector.getObject(1)); } diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidReaderImpl.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidReaderImpl.java index 6b98d3b340..2e5377ac91 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidReaderImpl.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidReaderImpl.java @@ -17,6 +17,7 @@ package org.apache.arrow.vector.complex.impl; import org.apache.arrow.vector.UuidVector; +import org.apache.arrow.vector.holder.NullableUuidHolder; import org.apache.arrow.vector.holder.UuidHolder; import org.apache.arrow.vector.holders.ExtensionHolder; import org.apache.arrow.vector.types.Types.MinorType; @@ -46,9 +47,18 @@ public boolean isSet() { return !vector.isNull(idx()); } - @Override public void read(ExtensionHolder holder) { - vector.get(idx(), (UuidHolder) holder); + if (holder instanceof NullableUuidHolder) { + vector.get(idx(), (NullableUuidHolder) holder); + } else if (holder instanceof UuidHolder) { + vector.get(idx(), (UuidHolder) holder); + } else { + throw new IllegalArgumentException("Holder type not supported"); + } + } + + public void read(NullableUuidHolder holder) { + vector.get(idx(), holder); } @Override @@ -66,4 +76,9 @@ public void copyAsValue(AbstractExtensionTypeWriter writer) { public Object readObject() { return vector.getObject(idx()); } + + @Override + public ExtensionTypeWriterFactory getExtensionTypeWriterFactory() { + return new UuidWriterFactory(); + } } diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java index f374eb41e4..b9f1500079 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java @@ -66,7 +66,6 @@ import org.apache.arrow.vector.complex.impl.UnionMapReader; import org.apache.arrow.vector.complex.impl.UnionReader; import org.apache.arrow.vector.complex.impl.UnionWriter; -import org.apache.arrow.vector.complex.impl.UuidWriterFactory; import org.apache.arrow.vector.complex.reader.BaseReader.StructReader; import org.apache.arrow.vector.complex.reader.BigIntReader; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -78,6 +77,7 @@ import org.apache.arrow.vector.complex.writer.BaseWriter.ListWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.MapWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter; +import org.apache.arrow.vector.holder.NullableUuidHolder; import org.apache.arrow.vector.holder.UuidHolder; import org.apache.arrow.vector.holders.DecimalHolder; import org.apache.arrow.vector.holders.DurationHolder; @@ -1105,6 +1105,13 @@ public void simpleUnion() throws Exception { new UnionVector("union", allocator, /* field type */ null, /* call-back */ null); UnionWriter unionWriter = new UnionWriter(vector); unionWriter.allocate(); + + UUID uuid = UUID.randomUUID(); + ByteBuffer bb = ByteBuffer.allocate(16); + bb.putLong(uuid.getMostSignificantBits()); + bb.putLong(uuid.getLeastSignificantBits()); + byte[] uuidByte = bb.array(); + for (int i = 0; i < COUNT; i++) { unionWriter.setPosition(i); if (i % 5 == 0) { @@ -1127,6 +1134,11 @@ public void simpleUnion() throws Exception { holder.buffer = buf; unionWriter.write(holder); bufs.add(buf); + } else if (i % 5 == 4) { + UuidHolder holder = new UuidHolder(); + + holder.value = uuidByte; + unionWriter.write(holder); } else { unionWriter.writeFloat4((float) i); } @@ -1152,6 +1164,13 @@ public void simpleUnion() throws Exception { unionReader.read(holder); assertEquals(i, holder.buffer.getInt(0)); assertEquals(4, holder.byteWidth); + } else if (i % 5 == 4) { + NullableUuidHolder holder = new NullableUuidHolder(); + unionReader.read(holder); + ByteBuffer b = ByteBuffer.wrap(holder.value); + long high = b.getLong(); + long low = b.getLong(); + assertEquals(new UUID(high, low), uuid); } else { assertEquals((float) i, unionReader.readFloat(), 1e-12); } @@ -2509,9 +2528,8 @@ public void extensionWriterReader() throws Exception { StructWriter rootWriter = writer.rootAsStruct(); { - ExtensionWriter extensionWriter = rootWriter.extension("uuid1", new UuidType()); + ExtensionWriter extensionWriter = rootWriter.extension("uuid1", UuidType.INSTANCE); extensionWriter.setPosition(0); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); extensionWriter.writeExtension(u1); } // read diff --git a/vector/src/test/java/org/apache/arrow/vector/holder/NullableUuidHolder.java b/vector/src/test/java/org/apache/arrow/vector/holder/NullableUuidHolder.java new file mode 100644 index 0000000000..5b5061e42d --- /dev/null +++ b/vector/src/test/java/org/apache/arrow/vector/holder/NullableUuidHolder.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.holder; + +import org.apache.arrow.vector.holders.ExtensionHolder; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.UuidType; + +public class NullableUuidHolder extends ExtensionHolder { + public byte[] value; + + @Override + public ArrowType type() { + return UuidType.INSTANCE; + } +} diff --git a/vector/src/test/java/org/apache/arrow/vector/holder/UuidHolder.java b/vector/src/test/java/org/apache/arrow/vector/holder/UuidHolder.java index 207b0951a7..301d8eddd1 100644 --- a/vector/src/test/java/org/apache/arrow/vector/holder/UuidHolder.java +++ b/vector/src/test/java/org/apache/arrow/vector/holder/UuidHolder.java @@ -17,7 +17,14 @@ package org.apache.arrow.vector.holder; import org.apache.arrow.vector.holders.ExtensionHolder; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.UuidType; public class UuidHolder extends ExtensionHolder { public byte[] value; + + @Override + public ArrowType type() { + return UuidType.INSTANCE; + } } diff --git a/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java b/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java index d24708d66c..3d6f9ff73d 100644 --- a/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java +++ b/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java @@ -43,10 +43,12 @@ import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.UuidVector; import org.apache.arrow.vector.ValueIterableVector; +import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.compare.Range; import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.writer.FieldWriter; import org.apache.arrow.vector.ipc.ArrowFileReader; import org.apache.arrow.vector.ipc.ArrowFileWriter; import org.apache.arrow.vector.types.FloatingPointPrecision; @@ -331,6 +333,11 @@ public String serialize() { public FieldVector getNewVector(String name, FieldType fieldType, BufferAllocator allocator) { return new LocationVector(name, allocator); } + + @Override + public FieldWriter getNewFieldWriter(ValueVector vector) { + throw new UnsupportedOperationException("Not yet implemented."); + } } public static class LocationVector extends ExtensionTypeVector diff --git a/vector/src/test/java/org/apache/arrow/vector/types/pojo/UuidType.java b/vector/src/test/java/org/apache/arrow/vector/types/pojo/UuidType.java index 5e2bd8881b..dbd88927ce 100644 --- a/vector/src/test/java/org/apache/arrow/vector/types/pojo/UuidType.java +++ b/vector/src/test/java/org/apache/arrow/vector/types/pojo/UuidType.java @@ -20,9 +20,13 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.FixedSizeBinaryVector; import org.apache.arrow.vector.UuidVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.impl.UuidWriterImpl; +import org.apache.arrow.vector.complex.writer.FieldWriter; import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; public class UuidType extends ExtensionType { + public static final UuidType INSTANCE = new UuidType(); @Override public ArrowType storageType() { @@ -57,4 +61,9 @@ public String serialize() { public FieldVector getNewVector(String name, FieldType fieldType, BufferAllocator allocator) { return new UuidVector(name, allocator, new FixedSizeBinaryVector(name, allocator, 16)); } + + @Override + public FieldWriter getNewFieldWriter(ValueVector vector) { + return new UuidWriterImpl((UuidVector) vector); + } }