/*
 * Copyright (C) 2016 Databricks, Inc.
 *
 * Portions of this software incorporate or are derived from software contained within Apache Spark,
 * and this modified software differs from the Apache Spark software provided under the Apache
 * License, Version 2.0, a copy of which you may obtain at
 * http://www.apache.org/licenses/LICENSE-2.0
 */

package com.databricks.spark.avro

import java.io.{IOException, OutputStream}
import java.nio.ByteBuffer
import java.sql.Timestamp
import java.util.HashMap

import scala.collection.immutable.Map

import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.generic.GenericData.Record
import org.apache.avro.generic.GenericRecord
import org.apache.avro.mapred.AvroKey
import org.apache.avro.mapreduce.AvroKeyOutputFormat
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types._

// NOTE: This class is instantiated and used on the executor side only, so it does not need to be
// serializable.
private[avro] class AvroOutputWriter(
    path: String,
    context: TaskAttemptContext,
    schema: StructType,
    recordName: String,
    recordNamespace: String) extends OutputWriter {

  private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace)

  /**
   * Overrides the two methods responsible for generating the output file and its stream, so
   * that the data is written to the partition-specific path supplied to this writer.
   */
  private val recordWriter: RecordWriter[AvroKey[GenericRecord], NullWritable] =
    new AvroKeyOutputFormat[GenericRecord]() {

      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
        new Path(path)
      }

      @throws(classOf[IOException])
      override def getAvroFileOutputStream(c: TaskAttemptContext): OutputStream = {
        val path = getDefaultWorkFile(context, ".avro")
        path.getFileSystem(context.getConfiguration).create(path)
      }

    }.getRecordWriter(context)

  override def write(row: Row): Unit = {
    val key = new AvroKey(converter(row).asInstanceOf[GenericRecord])
    recordWriter.write(key, NullWritable.get())
  }

  override def close(): Unit = recordWriter.close(context)

  /**
   * Constructs a converter function for the given Spark SQL data type. The converter is used
   * when writing Avro records out to disk.
   */
  private def createConverterToAvro(
      dataType: DataType,
      structName: String,
      recordNamespace: String): (Any) => Any = {
    dataType match {
      case BinaryType => (item: Any) => item match {
        case null => null
        case bytes: Array[Byte] => ByteBuffer.wrap(bytes)
      }
      case ByteType | ShortType | IntegerType | LongType |
           FloatType | DoubleType | StringType | BooleanType => identity
      // Decimals are written as their string representation.
      case _: DecimalType => (item: Any) => if (item == null) null else item.toString
      // Timestamps are written as epoch milliseconds (a long).
      case TimestampType => (item: Any) =>
        if (item == null) null else item.asInstanceOf[Timestamp].getTime
      case ArrayType(elementType, _) =>
        val elementConverter = createConverterToAvro(elementType, structName, recordNamespace)
        (item: Any) => {
          if (item == null) {
            null
          } else {
            val sourceArray = item.asInstanceOf[Seq[Any]]
            val sourceArraySize = sourceArray.size
            val targetArray = new Array[Any](sourceArraySize)
            var idx = 0
            while (idx < sourceArraySize) {
              targetArray(idx) = elementConverter(sourceArray(idx))
              idx += 1
            }
            targetArray
          }
        }
      case MapType(StringType, valueType, _) =>
        val valueConverter = createConverterToAvro(valueType, structName, recordNamespace)
        (item: Any) => {
          if (item == null) {
            null
          } else {
            val javaMap = new HashMap[String, Any]()
            item.asInstanceOf[Map[String, Any]].foreach { case (key, value) =>
              javaMap.put(key, valueConverter(value))
            }
            javaMap
          }
        }
      case structType: StructType =>
        // Nested structs become Avro records: build the record schema once, plus one converter
        // per field, and reuse them for every row.
        val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
        val schema: Schema = SchemaConverters.convertStructToAvro(
          structType, builder, recordNamespace)
        val fieldConverters = structType.fields.map(field =>
          createConverterToAvro(field.dataType, field.name, recordNamespace))
        (item: Any) => {
          if (item == null) {
            null
          } else {
            val record = new Record(schema)
            val convertersIterator = fieldConverters.iterator
            val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator
            val rowIterator = item.asInstanceOf[Row].toSeq.iterator

            while (convertersIterator.hasNext) {
              val converter = convertersIterator.next()
              record.put(fieldNamesIterator.next(), converter(rowIterator.next()))
            }
            record
          }
        }
    }
  }
}
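
For context, here is a minimal sketch of how this writer ends up being exercised from user code. It assumes a Spark 2.x SparkSession with the spark-avro package on the classpath; the output path and column names are illustrative. Each write task instantiates one AvroOutputWriter per output file, and the converter built from the DataFrame schema turns each Row into a GenericRecord before it is handed to the underlying AvroKeyOutputFormat.

import org.apache.spark.sql.SparkSession

object AvroWriteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("avro-write-example").getOrCreate()
    import spark.implicits._

    // A small DataFrame that exercises the converter cases above:
    // primitives, an array column and a string-keyed map column.
    val df = Seq(
      ("alice", Seq(1, 2, 3), Map("score" -> 1.0)),
      ("bob", Seq(4), Map("score" -> 2.0))
    ).toDF("name", "ids", "metrics")

    // Write through the spark-avro data source; each task writes its
    // partition to one .avro file under the (illustrative) output path.
    df.write
      .format("com.databricks.spark.avro")
      .save("/tmp/avro-write-example")

    spark.stop()
  }
}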