
Commit 244bcff

gengliangwang authored and gatorsmile committed
[SPARK-24811][SQL] Avro: add new function from_avro and to_avro
## What changes were proposed in this pull request?

Add a new function `from_avro` for parsing a binary column of Avro format and converting it into its corresponding Catalyst value.

Add a new function `to_avro` for converting a column into binary of Avro format with the specified schema.

This PR is in progress. Will add test cases.

## How was this patch tested?

Author: Gengliang Wang <[email protected]>

Closes #21774 from gengliangwang/from_and_to_avro.
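As a quick illustration of the new API, here is a minimal round-trip sketch. It is not part of this commit; the local SparkSession, DataFrame, and column names are assumptions made for the example, and it requires the external Avro module on the classpath.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.avro._

val spark = SparkSession.builder().master("local[*]").appName("avro-example").getOrCreate()
import spark.implicits._

val df = Seq(1, 2, 3).toDF("id")

// Encode the int column to Avro binary, then decode it back.
// The reader schema passed to from_avro must match the written data.
val encoded = df.select(to_avro($"id").as("value"))
val decoded = encoded.select(from_avro($"value", """{"type": "int", "name": "id"}""").as("id"))
decoded.show()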
1 parent cc4d64b commit 244bcff

File tree

6 files changed, +432 -0 lines changed
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.apache.avro.Schema
import org.apache.avro.generic.GenericDatumReader
import org.apache.avro.io.{BinaryDecoder, DecoderFactory}

import org.apache.spark.sql.avro.{AvroDeserializer, SchemaConverters}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType}

case class AvroDataToCatalyst(child: Expression, jsonFormatSchema: String)
  extends UnaryExpression with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)

  override lazy val dataType: DataType = SchemaConverters.toSqlType(avroSchema).dataType

  override def nullable: Boolean = true

  @transient private lazy val avroSchema = new Schema.Parser().parse(jsonFormatSchema)

  @transient private lazy val reader = new GenericDatumReader[Any](avroSchema)

  @transient private lazy val deserializer = new AvroDeserializer(avroSchema, dataType)

  // The decoder and the decoded record are reused across rows to avoid per-row allocation.
  @transient private var decoder: BinaryDecoder = _

  @transient private var result: Any = _

  override def nullSafeEval(input: Any): Any = {
    val binary = input.asInstanceOf[Array[Byte]]
    decoder = DecoderFactory.get().binaryDecoder(binary, 0, binary.length, decoder)
    result = reader.read(result, decoder)
    deserializer.deserialize(result)
  }

  override def simpleString: String = {
    s"from_avro(${child.sql}, ${dataType.simpleString})"
  }

  override def sql: String = {
    s"from_avro(${child.sql}, ${dataType.catalogString})"
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // The generated code simply calls back into nullSafeEval on this expression instance.
    val expr = ctx.addReferenceObj("this", this)
    defineCodeGen(ctx, ev, input =>
      s"(${CodeGenerator.boxedType(dataType)})$expr.nullSafeEval($input)")
  }
}
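For reference, the expression's result type comes from SchemaConverters.toSqlType, as in the dataType override above. A small sketch of that mapping; the record schema here is a hypothetical example, not taken from this commit:

import org.apache.avro.Schema
import org.apache.spark.sql.avro.SchemaConverters

// A hypothetical Avro record schema, just to show the type mapping.
val avroSchema = new Schema.Parser().parse(
  """{"type": "record", "name": "person", "fields": [
    |  {"name": "name", "type": "string"},
    |  {"name": "age", "type": "int"}]}""".stripMargin)

// Prints: struct<name:string,age:int>
println(SchemaConverters.toSqlType(avroSchema).dataType.catalogString)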
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import java.io.ByteArrayOutputStream

import org.apache.avro.generic.GenericDatumWriter
import org.apache.avro.io.{BinaryEncoder, EncoderFactory}

import org.apache.spark.sql.avro.{AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{BinaryType, DataType}

case class CatalystDataToAvro(child: Expression) extends UnaryExpression {

  override def dataType: DataType = BinaryType

  // The Avro schema is derived from the child expression's Catalyst type.
  @transient private lazy val avroType =
    SchemaConverters.toAvroType(child.dataType, child.nullable)

  @transient private lazy val serializer =
    new AvroSerializer(child.dataType, avroType, child.nullable)

  @transient private lazy val writer =
    new GenericDatumWriter[Any](avroType)

  @transient private var encoder: BinaryEncoder = _

  // The output buffer is reset and reused for every row.
  @transient private lazy val out = new ByteArrayOutputStream

  override def nullSafeEval(input: Any): Any = {
    out.reset()
    encoder = EncoderFactory.get().directBinaryEncoder(out, encoder)
    val avroData = serializer.serialize(input)
    writer.write(avroData, encoder)
    encoder.flush()
    out.toByteArray
  }

  override def simpleString: String = {
    s"to_avro(${child.sql}, ${child.dataType.simpleString})"
  }

  override def sql: String = {
    s"to_avro(${child.sql}, ${child.dataType.catalogString})"
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val expr = ctx.addReferenceObj("this", this)
    defineCodeGen(ctx, ev, input =>
      s"(byte[]) $expr.nullSafeEval($input)")
  }
}
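To see what the write path in nullSafeEval actually emits, here is a standalone sketch using the same Avro classes directly. It assumes a bare "int" schema and a non-null input; it is an illustration, not code from this commit.

import java.io.ByteArrayOutputStream

import org.apache.avro.Schema
import org.apache.avro.generic.GenericDatumWriter
import org.apache.avro.io.EncoderFactory

val out = new ByteArrayOutputStream()
val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
new GenericDatumWriter[Any](new Schema.Parser().parse("\"int\"")).write(1, encoder)
encoder.flush()

// Avro ints are zig-zag varints: the value 1 encodes as the single byte 0x02.
println(out.toByteArray.map(b => f"0x$b%02x").mkString(" "))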

external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala

Lines changed: 31 additions & 0 deletions
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql
 
+import org.apache.avro.Schema
+
+import org.apache.spark.annotation.Experimental
+
 package object avro {
   /**
    * Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using
@@ -36,4 +40,31 @@ package object avro {
     @scala.annotation.varargs
     def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*)
   }
+
+  /**
+   * Converts a binary column of avro format into its corresponding catalyst value. The specified
+   * schema must match the read data, otherwise the behavior is undefined: it may fail or return
+   * an arbitrary result.
+   *
+   * @param data the binary column.
+   * @param jsonFormatSchema the avro schema in JSON string format.
+   *
+   * @since 2.4.0
+   */
+  @Experimental
+  def from_avro(data: Column, jsonFormatSchema: String): Column = {
+    new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema))
+  }
+
+  /**
+   * Converts a column into binary of avro format.
+   *
+   * @param data the data column.
+   *
+   * @since 2.4.0
+   */
+  @Experimental
+  def to_avro(data: Column): Column = {
+    new Column(CatalystDataToAvro(data.expr))
+  }
 }
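These functions also compose with ordinary DataFrame code, and from_avro is not limited to data written by to_avro. A small sketch of decoding Avro binary produced elsewhere; the 0x02 payload is the Avro encoding of the int 1 (as in the encoder sketch above), and the SparkSession and column name are assumptions for the example:

import org.apache.spark.sql.avro._
import spark.implicits._   // assumes a SparkSession named `spark`

// A single-row binary column holding the Avro encoding of the int 1.
val raw = Seq(Array[Byte](0x02)).toDF("value")
val decoded = raw.select(from_avro($"value", """{"type": "int", "name": "v"}""").as("v"))
decoded.show()   // prints 1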
Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.avro

import org.apache.avro.Schema

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{AvroDataToCatalyst, CatalystDataToAvro, RandomDataGenerator}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalHelper {

  private def roundTripTest(data: Literal): Unit = {
    val avroType = SchemaConverters.toAvroType(data.dataType, data.nullable)
    checkResult(data, avroType.toString, data.eval())
  }

  private def checkResult(data: Literal, schema: String, expected: Any): Unit = {
    checkEvaluation(
      AvroDataToCatalyst(CatalystDataToAvro(data), schema),
      prepareExpectedResult(expected))
  }

  private def assertFail(data: Literal, schema: String): Unit = {
    intercept[java.io.EOFException] {
      AvroDataToCatalyst(CatalystDataToAvro(data), schema).eval()
    }
  }

  private val testingTypes = Seq(
    BooleanType,
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    DecimalType(8, 0),   // 32-bit decimal without fraction
    DecimalType(8, 4),   // 32-bit decimal
    DecimalType(16, 0),  // 64-bit decimal without fraction
    DecimalType(16, 11), // 64-bit decimal
    DecimalType(38, 0),
    DecimalType(38, 38),
    StringType,
    BinaryType)

  protected def prepareExpectedResult(expected: Any): Any = expected match {
    // Spark decimal is converted to an Avro string
    case d: Decimal => UTF8String.fromString(d.toString)
    // Spark byte and short both map to Avro int
    case b: Byte => b.toInt
    case s: Short => s.toInt
    case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult))
    case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult))
    case map: MapData =>
      val keys = new GenericArrayData(
        map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult))
      val values = new GenericArrayData(
        map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult))
      new ArrayBasedMapData(keys, values)
    case other => other
  }

  testingTypes.foreach { dt =>
    val seed = scala.util.Random.nextLong()
    test(s"single $dt with seed $seed") {
      val rand = new scala.util.Random(seed)
      val data = RandomDataGenerator.forType(dt, rand = rand).get.apply()
      val converter = CatalystTypeConverters.createToCatalystConverter(dt)
      val input = Literal.create(converter(data), dt)
      roundTripTest(input)
    }
  }

  for (_ <- 1 to 5) {
    val seed = scala.util.Random.nextLong()
    val rand = new scala.util.Random(seed)
    val schema = RandomDataGenerator.randomSchema(rand, 5, testingTypes)
    test(s"flat schema ${schema.catalogString} with seed $seed") {
      val data = RandomDataGenerator.randomRow(rand, schema)
      val converter = CatalystTypeConverters.createToCatalystConverter(schema)
      val input = Literal.create(converter(data), schema)
      roundTripTest(input)
    }
  }

  for (_ <- 1 to 5) {
    val seed = scala.util.Random.nextLong()
    val rand = new scala.util.Random(seed)
    val schema = RandomDataGenerator.randomNestedSchema(rand, 10, testingTypes)
    test(s"nested schema ${schema.catalogString} with seed $seed") {
      val data = RandomDataGenerator.randomRow(rand, schema)
      val converter = CatalystTypeConverters.createToCatalystConverter(schema)
      val input = Literal.create(converter(data), schema)
      roundTripTest(input)
    }
  }

  test("read int as string") {
    val data = Literal(1)
    val avroTypeJson =
      s"""
         |{
         |  "type": "string",
         |  "name": "my_string"
         |}
       """.stripMargin

    // When reading an int as a string, the Avro reader is not able to parse the binary and fails.
    assertFail(data, avroTypeJson)
  }

  test("read string as int") {
    val data = Literal("abc")
    val avroTypeJson =
      s"""
         |{
         |  "type": "int",
         |  "name": "my_int"
         |}
       """.stripMargin

    // When reading string data as an int, the Avro reader does not detect the type mismatch and
    // reads the string length as the int value.
    checkResult(data, avroTypeJson, 3)
  }

  test("read float as double") {
    val data = Literal(1.23f)
    val avroTypeJson =
      s"""
         |{
         |  "type": "double",
         |  "name": "my_double"
         |}
       """.stripMargin

    // When reading float data as a double, the Avro reader fails (it tries to read 8 bytes while
    // the data has only 4 bytes).
    assertFail(data, avroTypeJson)
  }

  test("read double as float") {
    val data = Literal(1.23)
    val avroTypeJson =
      s"""
         |{
         |  "type": "float",
         |  "name": "my_float"
         |}
       """.stripMargin

    // The Avro reader reads the first 4 bytes of a double as a float, so the result is undefined.
    checkResult(data, avroTypeJson, 5.848603E35f)
  }
}
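The "read string as int" expectation above follows from Avro's wire format: a string is written as a zig-zag varint length followed by UTF-8 bytes, and an int reader consumes only the leading varint. A standalone sketch of that behavior using the Avro classes directly, for illustration only:

import java.io.ByteArrayOutputStream

import org.apache.avro.Schema
import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter}
import org.apache.avro.io.{DecoderFactory, EncoderFactory}

// Write "abc" with the string schema: the bytes are 0x06 'a' 'b' 'c'
// (0x06 is the zig-zag varint for length 3).
val out = new ByteArrayOutputStream()
val enc = EncoderFactory.get().directBinaryEncoder(out, null)
new GenericDatumWriter[Any](new Schema.Parser().parse("\"string\"")).write("abc", enc)
enc.flush()
val bytes = out.toByteArray

// Read the same bytes with the int schema: only the leading varint is
// consumed, so the result is the string length, 3.
val dec = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null)
println(new GenericDatumReader[Any](new Schema.Parser().parse("\"int\"")).read(null, dec))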
