@@ -1210,38 +1210,56 @@ class SQLContext(@transient val sparkContext: SparkContext)
12101210 * Returns a Catalyst Schema for the given java bean class.
12111211 */
12121212 protected def getSchema (beanClass : Class [_]): Seq [AttributeReference ] = {
1213+ val (dataType, _) = inferDataType(beanClass)
1214+ dataType.asInstanceOf [StructType ].fields.map { f =>
1215+ AttributeReference (f.name, f.dataType, f.nullable)()
1216+ }
1217+ }
1218+
1219+ /**
1220+ * Infers the corresponding SQL data type of a Java class.
1221+ * @param clazz Java class
1222+ * @return (SQL data type, nullable)
1223+ */
1224+ private def inferDataType (clazz : Class [_]): (DataType , Boolean ) = {
12131225 // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
1214- val beanInfo = Introspector .getBeanInfo(beanClass)
1215-
1216- // Note: The ordering of elements may differ from when the schema is inferred in Scala.
1217- // This is because beanInfo.getPropertyDescriptors gives no guarantees about
1218- // element ordering.
1219- val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == " class" )
1220- fields.map { property =>
1221- val (dataType, nullable) = property.getPropertyType match {
1222- case c : Class [_] if c.isAnnotationPresent(classOf [SQLUserDefinedType ]) =>
1223- (c.getAnnotation(classOf [SQLUserDefinedType ]).udt().newInstance(), true )
1224- case c : Class [_] if c == classOf [java.lang.String ] => (StringType , true )
1225- case c : Class [_] if c == java.lang.Short .TYPE => (ShortType , false )
1226- case c : Class [_] if c == java.lang.Integer .TYPE => (IntegerType , false )
1227- case c : Class [_] if c == java.lang.Long .TYPE => (LongType , false )
1228- case c : Class [_] if c == java.lang.Double .TYPE => (DoubleType , false )
1229- case c : Class [_] if c == java.lang.Byte .TYPE => (ByteType , false )
1230- case c : Class [_] if c == java.lang.Float .TYPE => (FloatType , false )
1231- case c : Class [_] if c == java.lang.Boolean .TYPE => (BooleanType , false )
1232-
1233- case c : Class [_] if c == classOf [java.lang.Short ] => (ShortType , true )
1234- case c : Class [_] if c == classOf [java.lang.Integer ] => (IntegerType , true )
1235- case c : Class [_] if c == classOf [java.lang.Long ] => (LongType , true )
1236- case c : Class [_] if c == classOf [java.lang.Double ] => (DoubleType , true )
1237- case c : Class [_] if c == classOf [java.lang.Byte ] => (ByteType , true )
1238- case c : Class [_] if c == classOf [java.lang.Float ] => (FloatType , true )
1239- case c : Class [_] if c == classOf [java.lang.Boolean ] => (BooleanType , true )
1240- case c : Class [_] if c == classOf [java.math.BigDecimal ] => (DecimalType (), true )
1241- case c : Class [_] if c == classOf [java.sql.Date ] => (DateType , true )
1242- case c : Class [_] if c == classOf [java.sql.Timestamp ] => (TimestampType , true )
1243- }
1244- AttributeReference (property.getName, dataType, nullable)()
1226+ clazz match {
1227+ case c : Class [_] if c.isAnnotationPresent(classOf [SQLUserDefinedType ]) =>
1228+ (c.getAnnotation(classOf [SQLUserDefinedType ]).udt().newInstance(), true )
1229+
1230+ case c : Class [_] if c == classOf [java.lang.String ] => (StringType , true )
1231+ case c : Class [_] if c == java.lang.Short .TYPE => (ShortType , false )
1232+ case c : Class [_] if c == java.lang.Integer .TYPE => (IntegerType , false )
1233+ case c : Class [_] if c == java.lang.Long .TYPE => (LongType , false )
1234+ case c : Class [_] if c == java.lang.Double .TYPE => (DoubleType , false )
1235+ case c : Class [_] if c == java.lang.Byte .TYPE => (ByteType , false )
1236+ case c : Class [_] if c == java.lang.Float .TYPE => (FloatType , false )
1237+ case c : Class [_] if c == java.lang.Boolean .TYPE => (BooleanType , false )
1238+
1239+ case c : Class [_] if c == classOf [java.lang.Short ] => (ShortType , true )
1240+ case c : Class [_] if c == classOf [java.lang.Integer ] => (IntegerType , true )
1241+ case c : Class [_] if c == classOf [java.lang.Long ] => (LongType , true )
1242+ case c : Class [_] if c == classOf [java.lang.Double ] => (DoubleType , true )
1243+ case c : Class [_] if c == classOf [java.lang.Byte ] => (ByteType , true )
1244+ case c : Class [_] if c == classOf [java.lang.Float ] => (FloatType , true )
1245+ case c : Class [_] if c == classOf [java.lang.Boolean ] => (BooleanType , true )
1246+
1247+ case c : Class [_] if c == classOf [java.math.BigDecimal ] => (DecimalType (), true )
1248+ case c : Class [_] if c == classOf [java.sql.Date ] => (DateType , true )
1249+ case c : Class [_] if c == classOf [java.sql.Timestamp ] => (TimestampType , true )
1250+
1251+ case c : Class [_] if c.isArray =>
1252+ val (dataType, nullable) = inferDataType(c.getComponentType)
1253+ (ArrayType (dataType, nullable), true )
1254+
1255+ case _ =>
1256+ val beanInfo = Introspector .getBeanInfo(clazz)
1257+ val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == " class" )
1258+ val fields = properties.map { property =>
1259+ val (dataType, nullable) = inferDataType(property.getPropertyType)
1260+ new StructField (property.getName, dataType, nullable)
1261+ }
1262+ (new StructType (fields), true )
12451263 }
12461264 }
12471265}
0 commit comments