From a68321400c1068449698d03cebd0fbf648627133 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 3 Nov 2014 12:24:24 -0800 Subject: [PATCH 001/652] [SPARK-4148][PySpark] fix seed distribution and add some tests for rdd.sample The current way of seed distribution makes the random sequences from partition i and i+1 offset by 1. ~~~ In [14]: import random In [15]: r1 = random.Random(10) In [16]: r1.randint(0, 1) Out[16]: 1 In [17]: r1.random() Out[17]: 0.4288890546751146 In [18]: r1.random() Out[18]: 0.5780913011344704 In [19]: r2 = random.Random(10) In [20]: r2.randint(0, 1) Out[20]: 1 In [21]: r2.randint(0, 1) Out[21]: 0 In [22]: r2.random() Out[22]: 0.5780913011344704 ~~~ Note: The new tests are not for this bug fix. Author: Xiangrui Meng Closes #3010 from mengxr/SPARK-4148 and squashes the following commits: 869ae4b [Xiangrui Meng] move tests tests.py c1bacd9 [Xiangrui Meng] fix seed distribution and add some tests for rdd.sample (cherry picked from commit 3cca1962207745814b9d83e791713c91b659c36c) Signed-off-by: Xiangrui Meng --- python/pyspark/rdd.py | 3 --- python/pyspark/rddsampler.py | 11 +++++------ python/pyspark/tests.py | 15 +++++++++++++++ 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 550c9dd80522..4f025b9f1170 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -316,9 +316,6 @@ def sample(self, withReplacement, fraction, seed=None): """ Return a sampled subset of this RDD (relies on numpy and falls back on default random generator if numpy is unavailable). - - >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP - [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 528a181e8905..f5c3cfd259a5 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -40,14 +40,13 @@ def __init__(self, withReplacement, seed=None): def initRandomGenerator(self, split): if self._use_numpy: import numpy - self._random = numpy.random.RandomState(self._seed) + self._random = numpy.random.RandomState(self._seed ^ split) else: - self._random = random.Random(self._seed) + self._random = random.Random(self._seed ^ split) - for _ in range(0, split): - # discard the next few values in the sequence to have a - # different seed for the different splits - self._random.randint(0, 2 ** 32 - 1) + # mixing because the initial seeds are close to each other + for _ in xrange(10): + self._random.randint(0, 1) self._split = split self._rand_initialized = True diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 37a128907b3a..253a471849c3 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -648,6 +648,21 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_sample(self): + rdd = self.sc.parallelize(range(0, 100), 4) + wo = rdd.sample(False, 0.1, 2).collect() + wo_dup = rdd.sample(False, 0.1, 2).collect() + self.assertSetEqual(set(wo), set(wo_dup)) + wr = rdd.sample(True, 0.2, 5).collect() + wr_dup = rdd.sample(True, 0.2, 5).collect() + self.assertSetEqual(set(wr), set(wr_dup)) + wo_s10 = rdd.sample(False, 0.3, 10).collect() + wo_s20 = rdd.sample(False, 0.3, 20).collect() + self.assertNotEqual(set(wo_s10), set(wo_s20)) + wr_s11 = rdd.sample(True, 0.4, 
11).collect() + wr_s21 = rdd.sample(True, 0.4, 21).collect() + self.assertNotEqual(set(wr_s11), set(wr_s21)) + class ProfilerTests(PySparkTestCase): From fc782896b5d51161feee950107df2acf17e12422 Mon Sep 17 00:00:00 2001 From: fi Date: Mon, 3 Nov 2014 12:56:56 -0800 Subject: [PATCH 002/652] [SPARK-4211][Build] Fixes hive.version in Maven profile hive-0.13.1 instead of `hive.version=0.13.1`. e.g. mvn -Phive -Phive=0.13.1 Note: `hive.version=0.13.1a` is the default property value. However, when explicitly specifying the `hive-0.13.1` maven profile, the wrong one would be selected. References: PR #2685, which resolved a package incompatibility issue with Hive-0.13.1 by introducing a special version Hive-0.13.1a Author: fi Closes #3072 from coderfi/master and squashes the following commits: 7ca4b1e [fi] Fixes the `hive-0.13.1` maven profile referencing `hive.version=0.13.1` instead of the Spark compatible `hive.version=0.13.1a` Note: `hive.version=0.13.1a` is the default version. However, when explicitly specifying the `hive-0.13.1` maven profile, the wrong one would be selected. e.g. mvn -Phive -Phive=0.13.1 See PR #2685 (cherry picked from commit df607da025488d6c924d3d70eddb67f5523080d3) Signed-off-by: Michael Armbrust --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 6191cd3a541e..eb613531b8a5 100644 --- a/pom.xml +++ b/pom.xml @@ -1359,7 +1359,7 @@ false - 0.13.1 + 0.13.1a 0.13.1 10.10.1.1 From 292da4ef25d6cce23bfde7b9ab663a574dfd2b00 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Mon, 3 Nov 2014 13:07:41 -0800 Subject: [PATCH 003/652] [SPARK-4207][SQL] Query which has syntax like 'not like' is not working in Spark SQL Queries which has 'not like' is not working spark sql. sql("SELECT * FROM records where value not like 'val%'") same query works in Spark HiveQL Author: ravipesala Closes #3075 from ravipesala/SPARK-4207 and squashes the following commits: 35c11e7 [ravipesala] Supported 'not like' syntax in sql (cherry picked from commit 2b6e1ce6ee7b1ba8160bcbee97f5bbff5c46ca09) Signed-off-by: Michael Armbrust --- .../main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 1 + .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 00fc4d75c9ea..5e613e0f18ba 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -242,6 +242,7 @@ class SqlParser extends AbstractSparkSQLParser { | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } + | termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) } | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { case e1 ~ e2 => In(e1, e2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6bf439377aa3..702714af5308 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -938,4 +938,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { 
checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), (11 to 100).map(i => Seq(i))) } + + test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") { + checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), + (1 to 99).map(i => Seq(i))) + } } From cc5dc4247979dc001302f7af978801b789acdbfa Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 3 Nov 2014 13:17:09 -0800 Subject: [PATCH 004/652] [SPARK-3594] [PySpark] [SQL] take more rows to infer schema or sampling This patch will try to infer schema for RDD which has empty value (None, [], {}) in the first row. It will try first 100 rows and merge the types into schema, also merge fields of StructType together. If there is still NullType in schema, then it will show an warning, tell user to try with sampling. If sampling is presented, it will infer schema from all the rows after sampling. Also, add samplingRatio for jsonFile() and jsonRDD() Author: Davies Liu Author: Davies Liu Closes #2716 from davies/infer and squashes the following commits: e678f6d [Davies Liu] Merge branch 'master' of github.com:apache/spark into infer 34b5c63 [Davies Liu] Merge branch 'master' of github.com:apache/spark into infer 567dc60 [Davies Liu] update docs 9767b27 [Davies Liu] Merge branch 'master' into infer e48d7fb [Davies Liu] fix tests 29e94d5 [Davies Liu] let NullType inherit from PrimitiveType ee5d524 [Davies Liu] Merge branch 'master' of github.com:apache/spark into infer 540d1d5 [Davies Liu] merge fields for StructType f93fd84 [Davies Liu] add more tests 3603e00 [Davies Liu] take more rows to infer schema, or infer the schema by sampling the RDD (cherry picked from commit 24544fbce05665ab4999a1fe5aac434d29cd912c) Signed-off-by: Michael Armbrust --- python/pyspark/sql.py | 196 ++++++++++++------ python/pyspark/tests.py | 19 ++ .../spark/sql/catalyst/types/dataTypes.scala | 2 +- 3 files changed, 148 insertions(+), 69 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 98e41f857567..675df084bf30 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -109,6 +109,15 @@ def __eq__(self, other): return self is other +class NullType(PrimitiveType): + + """Spark SQL NullType + + The data type representing None, used for the types which has not + been inferred. + """ + + class StringType(PrimitiveType): """Spark SQL StringType @@ -331,7 +340,7 @@ class StructField(DataType): """ - def __init__(self, name, dataType, nullable, metadata=None): + def __init__(self, name, dataType, nullable=True, metadata=None): """Creates a StructField :param name: the name of this field. :param dataType: the data type of this field. 
@@ -484,6 +493,7 @@ def _parse_datatype_json_value(json_value): # Mapping Python types to Spark SQL DataType _type_mappings = { + type(None): NullType, bool: BooleanType, int: IntegerType, long: LongType, @@ -500,22 +510,22 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): """Infer the DataType from obj""" - if obj is None: - raise ValueError("Can not infer type for None") - dataType = _type_mappings.get(type(obj)) if dataType is not None: return dataType() if isinstance(obj, dict): - if not obj: - raise ValueError("Can not infer type for empty dict") - key, value = obj.iteritems().next() - return MapType(_infer_type(key), _infer_type(value), True) + for key, value in obj.iteritems(): + if key is not None and value is not None: + return MapType(_infer_type(key), _infer_type(value), True) + else: + return MapType(NullType(), NullType(), True) elif isinstance(obj, (list, array)): - if not obj: - raise ValueError("Can not infer type for empty list/array") - return ArrayType(_infer_type(obj[0]), True) + for v in obj: + if v is not None: + return ArrayType(_infer_type(obj[0]), True) + else: + return ArrayType(NullType(), True) else: try: return _infer_schema(obj) @@ -548,60 +558,93 @@ def _infer_schema(row): return StructType(fields) -def _create_converter(obj, dataType): +def _has_nulltype(dt): + """ Return whether there is NullType in `dt` or not """ + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +def _merge_type(a, b): + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) + raise TypeError("Can not merge type %s and %s" % (a, b)) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in b.fields) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + for f in a.fields] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType(_merge_type(a.elementType, b.elementType), True) + + elif isinstance(a, MapType): + return MapType(_merge_type(a.keyType, b.keyType), + _merge_type(a.valueType, b.valueType), + True) + else: + return a + + +def _create_converter(dataType): """Create an converter to drop the names of fields in obj """ if isinstance(dataType, ArrayType): - conv = _create_converter(obj[0], dataType.elementType) + conv = _create_converter(dataType.elementType) return lambda row: map(conv, row) elif isinstance(dataType, MapType): - value = obj.values()[0] - conv = _create_converter(value, dataType.valueType) + conv = _create_converter(dataType.valueType) return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + elif isinstance(dataType, NullType): + return lambda x: None + elif not isinstance(dataType, StructType): return lambda x: x # dataType must be StructType names = [f.name for f in dataType.fields] + converters = [_create_converter(f.dataType) for f in dataType.fields] + + def convert_struct(obj): + if obj is None: + return + + if isinstance(obj, tuple): + if hasattr(obj, "fields"): + d = dict(zip(obj.fields, obj)) + if hasattr(obj, "__FIELDS__"): + d = dict(zip(obj.__FIELDS__, obj)) + elif 
all(isinstance(x, tuple) and len(x) == 2 for x in obj): + d = dict(obj) + else: + raise ValueError("unexpected tuple: %s" % obj) - if isinstance(obj, dict): - conv = lambda o: tuple(o.get(n) for n in names) - - elif isinstance(obj, tuple): - if hasattr(obj, "_fields"): # namedtuple - conv = tuple - elif hasattr(obj, "__FIELDS__"): - conv = tuple - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - conv = lambda o: tuple(v for k, v in o) + elif isinstance(obj, dict): + d = obj + elif hasattr(obj, "__dict__"): # object + d = obj.__dict__ else: - raise ValueError("unexpected tuple") + raise ValueError("Unexpected obj: %s" % obj) - elif hasattr(obj, "__dict__"): # object - conv = lambda o: [o.__dict__.get(n, None) for n in names] + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields): - return conv - - row = conv(obj) - convs = [_create_converter(v, f.dataType) - for v, f in zip(row, dataType.fields)] - - def nested_conv(row): - return tuple(f(v) for f, v in zip(convs, conv(row))) - - return nested_conv - - -def _drop_schema(rows, schema): - """ all the names of fields, becoming tuples""" - iterator = iter(rows) - row = iterator.next() - converter = _create_converter(row, schema) - yield converter(row) - for i in iterator: - yield converter(i) + return convert_struct _BRACKETS = {'(': ')', '[': ']', '{': '}'} @@ -713,7 +756,7 @@ def _infer_schema_type(obj, dataType): return _infer_type(obj) if not obj: - raise ValueError("Can not infer type from empty value") + return NullType() if isinstance(dataType, ArrayType): eType = _infer_schema_type(obj[0], dataType.elementType) @@ -1049,18 +1092,20 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._javaAccumulator, returnType.json()) - def inferSchema(self, rdd): + def inferSchema(self, rdd, samplingRatio=None): """Infer and apply a schema to an RDD of L{Row}. - We peek at the first row of the RDD to determine the fields' names - and types. Nested collections are supported, which include array, - dict, list, Row, tuple, namedtuple, or object. + When samplingRatio is specified, the schema is inferred by looking + at the types of each row in the sampled dataset. Otherwise, the + first 100 rows of the RDD are inspected. Nested collections are + supported, which can include array, dict, list, Row, tuple, + namedtuple, or object. - All the rows in `rdd` should have the same type with the first one, - or it will cause runtime exceptions. + Each row could be L{pyspark.sql.Row} object or namedtuple or objects. + Using top level dicts is deprecated, as dict is used to represent Maps. - Each row could be L{pyspark.sql.Row} object or namedtuple or objects, - using dict is deprecated. + If a single column has multiple distinct inferred types, it may cause + runtime exceptions. >>> rdd = sc.parallelize( ... 
[Row(field1=1, field2="row1"), @@ -1097,8 +1142,23 @@ def inferSchema(self, rdd): warnings.warn("Using RDD of dict to inferSchema is deprecated," "please use pyspark.sql.Row instead") - schema = _infer_schema(first) - rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + warnings.warn("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio > 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + + converter = _create_converter(schema) + rdd = rdd.map(converter) return self.applySchema(rdd, schema) def applySchema(self, rdd, schema): @@ -1219,7 +1279,7 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path, schema=None): + def jsonFile(self, path, schema=None, samplingRatio=1.0): """ Loads a text file storing one JSON object per line as a L{SchemaRDD}. @@ -1227,8 +1287,8 @@ def jsonFile(self, path, schema=None): If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -1274,20 +1334,20 @@ def jsonFile(self, path, schema=None): [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - srdd = self._ssql_ctx.jsonFile(path) + srdd = self._ssql_ctx.jsonFile(path, samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) - def jsonRDD(self, rdd, schema=None): + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. 
>>> srdd1 = sqlCtx.jsonRDD(json) >>> sqlCtx.registerRDDAsTable(srdd1, "table1") @@ -1344,7 +1404,7 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 253a471849c3..68fd75687621 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -796,6 +796,25 @@ def test_serialize_nested_array_and_map(self): self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) + def test_infer_schema(self): + d = [Row(l=[], d={}), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + srdd = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], srdd.map(lambda r: r.l).first()) + self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) + srdd.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + + srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(srdd.schema(), srdd2.schema()) + self.assertEqual({}, srdd2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) + srdd2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index cc5015ad3c01..e1b5992a36e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -213,7 +213,7 @@ trait PrimitiveType extends DataType { } object PrimitiveType { - private val nonDecimals = Seq(DateType, TimestampType, BinaryType) ++ NativeType.all + private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap /** Given the string representation of a type, return its DataType */ From 572300ba8a5f24b52f19d7033a456248da20bfed Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 3 Nov 2014 13:20:33 -0800 Subject: [PATCH 005/652] [SPARK-4202][SQL] Simple DSL support for Scala UDF This feature is based on an offline discussion with mengxr, hopefully can be useful for the new MLlib pipeline API. 
For the following test snippet ```scala case class KeyValue(key: Int, value: String) val testData = sc.parallelize(1 to 10).map(i => KeyValue(i, i.toString)).toSchemaRDD def foo(a: Int, b: String) => a.toString + b ``` the newly introduced DSL enables the following syntax ```scala import org.apache.spark.sql.catalyst.dsl._ testData.select(Star(None), foo.call('key, 'value) as 'result) ``` which is equivalent to ```scala testData.registerTempTable("testData") sqlContext.registerFunction("foo", foo) sql("SELECT *, foo(key, value) AS result FROM testData") ``` Author: Cheng Lian Closes #3067 from liancheng/udf-dsl and squashes the following commits: f132818 [Cheng Lian] Adds DSL support for Scala UDF (cherry picked from commit c238fb423d1011bd1b1e6201d769b72e52664fc6) Signed-off-by: Michael Armbrust --- .../spark/sql/catalyst/dsl/package.scala | 59 +++++++++++++++++++ .../org/apache/spark/sql/DslQuerySuite.scala | 17 ++++-- 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 7e6d770314f5..3314e1547701 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ @@ -285,4 +286,62 @@ package object dsl { def writeToFile(path: String) = WriteToFile(path, logicalPlan) } } + + case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { + def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + } + + // scalastyle:off + /** functionToUdfBuilder 1-22 were generated by this script + + (1 to 22).map { x => + val argTypes = Seq.fill(x)("_").mkString(", ") + s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)" + } + */ + + implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, 
_, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + // scalastyle:on } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 45e58afe9d9a..e70ad891eea3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.test._ /* Implicits */ -import TestSQLContext._ +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.test.TestSQLContext._ class DslQuerySuite extends QueryTest { - import TestData._ + import org.apache.spark.sql.TestData._ test("table scan") { checkAnswer( @@ -216,4 +215,14 @@ class DslQuerySuite extends QueryTest { (4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) } + + test("udf") { + val foo = (a: Int, b: String) => a.toString + b + + checkAnswer( + // SELECT *, foo(key, value) FROM testData + testData.select(Star(None), foo.call('key, 'value)).limit(3), + (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil + ) + } } From 6104754f711da9eb0c09daf377bcd750d2d23f8a Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 3 Nov 2014 13:59:43 -0800 Subject: [PATCH 006/652] [SPARK-4152] [SQL] Avoid data change in CTAS while table already existed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CREATE TABLE t1 (a String); CREATE TABLE t1 AS SELECT key FROM src; – throw exception CREATE TABLE if not exists t1 AS SELECT key FROM src; – expect do nothing, currently it will overwrite the t1, which is incorrect. 
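A minimal sketch of the intended semantics described above (illustrative only, not part of the patch; it assumes a `HiveContext` value named `hiveContext` and a populated `src` table, as in Hive's standard test data):

```scala
// t1 already exists before the CTAS statements run.
hiveContext.sql("CREATE TABLE t1 (a STRING)")

// IF NOT EXISTS: after this patch the existing t1 is left untouched
// instead of being silently overwritten.
hiveContext.sql("CREATE TABLE IF NOT EXISTS t1 AS SELECT key FROM src")

// Plain CTAS on an existing table: fails fast with
// org.apache.hadoop.hive.metastore.api.AlreadyExistsException.
hiveContext.sql("CREATE TABLE t1 AS SELECT key FROM src")
```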
Author: Cheng Hao Closes #3013 from chenghao-intel/ctas_unittest and squashes the following commits: 194113e [Cheng Hao] fix bug in CTAS when table already existed (cherry picked from commit e83f13e8d37ca33f4e183e977d077221b90c6025) Signed-off-by: Michael Armbrust --- .../spark/sql/catalyst/analysis/Catalog.scala | 22 +++++++++++++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 6 +++++ .../hive/execution/CreateTableAsSelect.scala | 12 +++++++++- .../sql/hive/execution/SQLQuerySuite.scala | 9 ++++++-- 4 files changed, 46 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 2059a91ba061..0415d74bd814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -28,6 +28,8 @@ trait Catalog { def caseSensitive: Boolean + def tableExists(db: Option[String], tableName: String): Boolean + def lookupRelation( databaseName: Option[String], tableName: String, @@ -82,6 +84,14 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tables.clear() } + override def tableExists(db: Option[String], tableName: String): Boolean = { + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + tables.get(tblName) match { + case Some(_) => true + case None => false + } + } + override def lookupRelation( databaseName: Option[String], tableName: String, @@ -107,6 +117,14 @@ trait OverrideCatalog extends Catalog { // TODO: This doesn't work when the database changes... val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() + abstract override def tableExists(db: Option[String], tableName: String): Boolean = { + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + overrides.get((dbName, tblName)) match { + case Some(_) => true + case None => super.tableExists(db, tableName) + } + } + abstract override def lookupRelation( databaseName: Option[String], tableName: String, @@ -149,6 +167,10 @@ object EmptyCatalog extends Catalog { val caseSensitive: Boolean = true + def tableExists(db: Option[String], tableName: String): Boolean = { + throw new UnsupportedOperationException + } + def lookupRelation( databaseName: Option[String], tableName: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 096b4a07aa2e..0baf4c9f8c7a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -57,6 +57,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val caseSensitive: Boolean = false + def tableExists(db: Option[String], tableName: String): Boolean = { + val (databaseName, tblName) = processDatabaseAndTableName( + db.getOrElse(hive.sessionState.getCurrentDatabase), tableName) + client.getTable(databaseName, tblName, false) != null + } + def lookupRelation( db: Option[String], tableName: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 2fce41473457..3d24d87bc3d3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -71,7 +71,17 @@ case class CreateTableAsSelect( // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. - sc.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + if (sc.catalog.tableExists(Some(database), tableName)) { + if (allowExisting) { + // table already exists, will do nothing, to keep consistent with Hive + } else { + throw + new org.apache.hadoop.hive.metastore.api.AlreadyExistsException(s"$database.$tableName") + } + } else { + sc.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + } + Seq.empty[Row] } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 76a0ec01a607..e9b1943ff8db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -56,7 +56,7 @@ class SQLQuerySuite extends QueryTest { sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect - // expect the string => integer for field key cause the table ctas4 already existed. + // do nothing cause the table ctas4 already existed. sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect @@ -78,9 +78,14 @@ class SQLQuerySuite extends QueryTest { SELECT key, value FROM src ORDER BY key, value""").collect().toSeq) + intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] { + sql( + """CREATE TABLE ctas4 AS + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + } checkAnswer( sql("SELECT key, value FROM ctas4 ORDER BY key, value"), - sql("SELECT CAST(key AS int) k, value FROM src ORDER BY k, value").collect().toSeq) + sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) checkExistence(sql("DESC EXTENDED ctas2"), true, "name:key", "type:string", "name:value", "ctas2", From 51985f78ca5f728f8b9233b703110f541d27b274 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 3 Nov 2014 14:08:27 -0800 Subject: [PATCH 007/652] [SQL] More aggressive defaults - Turns on compression for in-memory cached data by default - Changes the default parquet compression format back to gzip (we have seen more OOMs with production workloads due to the way Snappy allocates memory) - Ups the batch size to 10,000 rows - Increases the broadcast threshold to 10mb. - Uses our parquet implementation instead of the hive one by default. - Cache parquet metadata by default. Author: Michael Armbrust Closes #3064 from marmbrus/fasterDefaults and squashes the following commits: 97ee9f8 [Michael Armbrust] parquet codec docs e641694 [Michael Armbrust] Remote also a12866a [Michael Armbrust] Cache metadata. 2d73acc [Michael Armbrust] Update docs defaults. 
d63d2d5 [Michael Armbrust] document parquet option da373f9 [Michael Armbrust] More aggressive defaults (cherry picked from commit 25bef7e6951301e93004567fc0cef96bf8d1a224) Signed-off-by: Michael Armbrust --- docs/sql-programming-guide.md | 18 +++++++++++++----- .../scala/org/apache/spark/sql/SQLConf.scala | 10 +++++----- .../sql/parquet/ParquetTableOperations.scala | 6 +++--- .../apache/spark/sql/hive/HiveContext.scala | 2 +- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d4ade939c3a6..e399fecbbc78 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -582,19 +582,27 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or spark.sql.parquet.cacheMetadata - false + true Turns on caching of Parquet schema metadata. Can speed up querying of static data. spark.sql.parquet.compression.codec - snappy + gzip Sets the compression codec use when writing Parquet files. Acceptable values include: uncompressed, snappy, gzip, lzo. + + spark.sql.hive.convertMetastoreParquet + true + + When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of the built in + support. + + ## JSON Datasets @@ -815,7 +823,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL Property NameDefaultMeaning spark.sql.inMemoryColumnarStorage.compressed - false + true When set to true Spark SQL will automatically select a compression codec for each column based on statistics of the data. @@ -823,7 +831,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL spark.sql.inMemoryColumnarStorage.batchSize - 1000 + 10000 Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization and compression, but risk OOMs when caching data. @@ -841,7 +849,7 @@ that these options will be deprecated in future release as more optimizations ar Property NameDefaultMeaning spark.sql.autoBroadcastJoinThreshold - 10000 + 10485760 (10 MB) Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 07e6e2eccddf..279495aa6475 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -79,13 +79,13 @@ private[sql] trait SQLConf { private[spark] def dialect: String = getConf(DIALECT, "sql") /** When true tables cached using the in-memory columnar caching will be compressed. */ - private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean + private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "true").toBoolean /** The compression codec for writing to a Parquetfile */ - private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "snappy") + private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "gzip") /** The number of rows that will be */ - private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "1000").toInt + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "10000").toInt /** Number of partitions to use for shuffle operators. 
*/ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt @@ -106,10 +106,10 @@ private[sql] trait SQLConf { * a broadcast value during the physical executions of join operations. Setting this to -1 * effectively disables auto conversion. * - * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is also 10000. + * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is 10000. */ private[spark] def autoBroadcastJoinThreshold: Int = - getConf(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt + getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 9664c565a0b8..d00860a8bb8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -123,7 +123,7 @@ case class ParquetTableScan( // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata conf.set( SQLConf.PARQUET_CACHE_METADATA, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "false")) + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true")) val baseRDD = new org.apache.spark.rdd.NewHadoopRDD( @@ -394,7 +394,7 @@ private[parquet] class FilteringParquetRowInputFormat if (footers eq null) { val conf = ContextUtil.getConfiguration(jobContext) - val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) val statuses = listStatus(jobContext) fileStatuses = statuses.map(file => file.getPath -> file).toMap if (statuses.isEmpty) { @@ -493,7 +493,7 @@ private[parquet] class FilteringParquetRowInputFormat import parquet.filter2.compat.FilterCompat.Filter; import parquet.filter2.compat.RowGroupFilter; - val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] val filter: Filter = ParquetInputFormat.getFilter(configuration) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index f025169ad506..e88afaaf001c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -90,7 +90,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * SerDe. */ private[spark] def convertMetastoreParquet: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet", "false") == "true" + getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } From fa86d862f98cfea3d9afff6e61b3141c9b08f949 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Mon, 3 Nov 2014 15:19:01 -0800 Subject: [PATCH 008/652] SPARK-4178. Hadoop input metrics ignore bytes read in RecordReader insta... 
...ntiation Author: Sandy Ryza Closes #3045 from sryza/sandy-spark-4178 and squashes the following commits: 8d2e70e [Sandy Ryza] Kostas's review feedback e5b27c0 [Sandy Ryza] SPARK-4178. Hadoop input metrics ignore bytes read in RecordReader instantiation (cherry picked from commit 28128150e7e0c2b7d1c483e67214bdaef59f7d75) Signed-off-by: Patrick Wendell --- .../org/apache/spark/rdd/HadoopRDD.scala | 25 +++++++++-------- .../org/apache/spark/rdd/NewHadoopRDD.scala | 26 +++++++++--------- .../spark/metrics/InputMetricsSuite.scala | 27 +++++++++++++++++-- 3 files changed, 53 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 946fb5616d3e..a157e36e2286 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -211,20 +211,11 @@ class HadoopRDD[K, V]( val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) - var reader: RecordReader[K, V] = null val jobConf = getJobConf() - val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.stageId, theSplit.index, context.attemptId.toInt, jobConf) - reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) - - // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener{ context => closeIfNeeded() } - val key: K = reader.createKey() - val value: V = reader.createValue() val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - // Find a function that will return the FileSystem bytes read by this thread. + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) { SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf) @@ -234,6 +225,18 @@ class HadoopRDD[K, V]( if (bytesReadCallback.isDefined) { context.taskMetrics.inputMetrics = Some(inputMetrics) } + + var reader: RecordReader[K, V] = null + val inputFormat = getInputFormat(jobConf) + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), + context.stageId, theSplit.index, context.attemptId.toInt, jobConf) + reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + + // Register an on-task-completion callback to close the input stream. 
+ context.addTaskCompletionListener{ context => closeIfNeeded() } + val key: K = reader.createKey() + val value: V = reader.createValue() + var recordsSinceMetricsUpdate = 0 override def getNext() = { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 6d6b86721ca7..351e145f96f9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -107,20 +107,10 @@ class NewHadoopRDD[K, V]( val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) val conf = confBroadcast.value.value - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) - val format = inputFormatClass.newInstance - format match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - // Find a function that will return the FileSystem bytes read by this thread. + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf) @@ -131,6 +121,18 @@ class NewHadoopRDD[K, V]( context.taskMetrics.inputMetrics = Some(inputMetrics) } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + format match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + // Register an on-task-completion callback to close the input stream. 
context.addTaskCompletionListener(context => close()) var havePair = false diff --git a/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala index 33bd1afea247..48c386ba0431 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala @@ -27,7 +27,7 @@ import scala.collection.mutable.ArrayBuffer import java.io.{FileWriter, PrintWriter, File} class InputMetricsSuite extends FunSuite with SharedSparkContext { - test("input metrics when reading text file") { + test("input metrics when reading text file with single split") { val file = new File(getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(file)) pw.println("some stuff") @@ -48,6 +48,29 @@ class InputMetricsSuite extends FunSuite with SharedSparkContext { // Wait for task end events to come in sc.listenerBus.waitUntilEmpty(500) assert(taskBytesRead.length == 2) - assert(taskBytesRead.sum == file.length()) + assert(taskBytesRead.sum >= file.length()) + } + + test("input metrics when reading text file with multiple splits") { + val file = new File(getClass.getSimpleName + ".txt") + val pw = new PrintWriter(new FileWriter(file)) + for (i <- 0 until 10000) { + pw.println("some stuff") + } + pw.close() + file.deleteOnExit() + + val taskBytesRead = new ArrayBuffer[Long]() + sc.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead + } + }) + sc.textFile("file://" + file.getAbsolutePath, 2).count() + + // Wait for task end events to come in + sc.listenerBus.waitUntilEmpty(500) + assert(taskBytesRead.length == 2) + assert(taskBytesRead.sum >= file.length()) } } From 52db2b9429e00d8ed398a2432ad6a26cd1e5920c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 3 Nov 2014 18:04:51 -0800 Subject: [PATCH 009/652] [SQL] Convert arguments to Scala UDFs Author: Michael Armbrust Closes #3077 from marmbrus/udfsWithUdts and squashes the following commits: 34b5f27 [Michael Armbrust] style 504adef [Michael Armbrust] Convert arguments to Scala UDFs (cherry picked from commit 15b58a2234ab7ba30c9c0cbb536177a3c725e350) Signed-off-by: Michael Armbrust --- .../sql/catalyst/expressions/ScalaUdf.scala | 560 ++++++++++-------- .../spark/sql/UserDefinedTypeSuite.scala | 18 +- 2 files changed, 316 insertions(+), 262 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index fa1786e74bb3..18c96da2f87f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -34,320 +34,366 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi override def toString = s"scalaUDF(${children.mkString(",")})" + // scalastyle:off + /** This method has been generated by this script (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val evals = (0 to x - 1).map(x => s"children($x).eval(input)").reduce(_ + ",\n " + _) + val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _) s""" case $x => function.asInstanceOf[($anys) => Any]( - $evals) + $evals) """ - } + 
}.foreach(println) */ - // scalastyle:off override def eval(input: Row): Any = { val result = children.size match { case 0 => function.asInstanceOf[() => Any]() - case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) + case 1 => + function.asInstanceOf[(Any) => Any]( + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType)) + + case 2 => function.asInstanceOf[(Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType)) + + case 3 => function.asInstanceOf[(Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType)) + + case 4 => function.asInstanceOf[(Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType)) + + case 5 => function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType)) + + case 6 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType)) + + case 7 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + 
ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType)) + + case 8 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType)) + + case 9 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType)) + + case 10 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType)) + + case 11 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - 
children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType)) + + case 12 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType)) + + case 13 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + 
ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType)) + + case 14 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType)) + + case 15 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), 
children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType)) + + case 16 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType)) + + case 17 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), 
children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType)) + + case 18 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType)) + + case 19 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + 
ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType)) + + case 20 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), 
children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType)) + + case 21 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input), - children(20).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), + ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType)) + + case 22 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input), - children(20).eval(input), - children(21).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), 
children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), + ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType), + ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType)) + } // scalastyle:on diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 666235e57f81..1806a1dd8202 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -60,13 +60,13 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } class UserDefinedTypeSuite extends QueryTest { + val points = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) + val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) - test("register user type: MyDenseVector for MyLabeledPoint") { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) + test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) @@ -80,4 +80,12 @@ class UserDefinedTypeSuite extends QueryTest { assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) } + + test("UDTs and UDFs") { + registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + pointsRDD.registerTempTable("points") + checkAnswer( + sql("SELECT testType(features) from points"), + Seq(Row(true), Row(true))) + } } From 0826eed9c84a73544e3d8289834c8b5ebac47e03 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 3 Nov 2014 
18:50:37 -0800 Subject: [PATCH 010/652] [FIX][MLLIB] fix seed in BaggedPointSuite Saw Jenkins test failures due to random seeds. jkbradley manishamde Author: Xiangrui Meng Closes #3084 from mengxr/fix-baggedpoint-suite and squashes the following commits: f735a43 [Xiangrui Meng] fix seed in BaggedPointSuite (cherry picked from commit c5912ecc7b392a13089ae735c07c2d7256de36c6) Signed-off-by: Xiangrui Meng --- .../spark/mllib/tree/impl/BaggedPointSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala index c0a62e00432a..5cb433232e71 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -30,7 +30,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { test("BaggedPoint RDD: without subsampling") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) baggedRDD.collect().foreach { baggedPoint => assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) } @@ -44,7 +44,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) @@ -60,7 +60,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) @@ -75,7 +75,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) @@ -91,7 +91,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) val subsampleCounts: 
Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) From 42d02db86cd973cf31ceeede0c5a723238bbe746 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 3 Nov 2014 19:29:11 -0800 Subject: [PATCH 011/652] [SPARK-4192][SQL] Internal API for Python UDT Following #2919, this PR adds Python UDT (for internal use only) with tests under "pyspark.tests". Before `SQLContext.applySchema`, we check whether we need to convert user-type instances into SQL recognizable data. In the current implementation, a Python UDT must be paired with a Scala UDT for serialization on the JVM side. A following PR will add VectorUDT in MLlib for both Scala and Python. marmbrus jkbradley davies Author: Xiangrui Meng Closes #3068 from mengxr/SPARK-4192-sql and squashes the following commits: acff637 [Xiangrui Meng] merge master dba5ea7 [Xiangrui Meng] only use pyClass for Python UDT output sqlType as well 2c9d7e4 [Xiangrui Meng] move import to global setup; update needsConversion 7c4a6a9 [Xiangrui Meng] address comments 75223db [Xiangrui Meng] minor update f740379 [Xiangrui Meng] remove UDT from default imports e98d9d0 [Xiangrui Meng] fix py style 4e84fce [Xiangrui Meng] remove local hive tests and add more tests 39f19e0 [Xiangrui Meng] add tests b7f666d [Xiangrui Meng] add Python UDT (cherry picked from commit 04450d11548cfb25d4fb77d4a33e3a7cd4254183) Signed-off-by: Xiangrui Meng --- python/pyspark/sql.py | 206 +++++++++++++++++- python/pyspark/tests.py | 93 +++++++- .../spark/sql/catalyst/types/dataTypes.scala | 9 +- .../org/apache/spark/sql/SQLContext.scala | 2 + .../spark/sql/execution/pythonUdfs.scala | 5 + .../spark/sql/test/ExamplePointUDT.scala | 64 ++++++ .../sql/types/util/DataTypeConversions.scala | 1 - 7 files changed, 375 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 675df084bf30..d16c18bc79fe 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -417,6 +417,75 @@ def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) +class UserDefinedType(DataType): + """ + :: WARN: Spark Internal Use Only :: + SQL User-Defined Type (UDT). + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT. + """ + raise NotImplementedError("UDT must have a paired Scala UDT.") + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement serialize().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. 
+ """ + raise NotImplementedError("UDT must implement deserialize().") + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = json["pyClass"] + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass], -1) + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) + + _all_primitive_types = dict((v.typeName(), v) for v in globals().itervalues() if type(v) is PrimitiveTypeSingleton and @@ -469,6 +538,12 @@ def _parse_datatype_json_string(json_string): ... complex_arraytype, False) >>> check_datatype(complex_maptype) True + >>> check_datatype(ExamplePointUDT()) + True + >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + True """ return _parse_datatype_json_value(json.loads(json_string)) @@ -488,7 +563,13 @@ def _parse_datatype_json_value(json_value): else: raise ValueError("Could not parse datatype: %s" % json_value) else: - return _all_complex_types[json_value["type"]].fromJson(json_value) + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) # Mapping Python types to Spark SQL DataType @@ -509,7 +590,18 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): - """Infer the DataType from obj""" + """Infer the DataType from obj + + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ + if obj is None: + raise ValueError("Can not infer type for None") + + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + dataType = _type_mappings.get(type(obj)) if dataType is not None: return dataType() @@ -558,6 +650,93 @@ def _infer_schema(row): return StructType(fields) +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... 
StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. + + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) + + def _has_nulltype(dt): """ Return whether there is NullType in `dt` or not """ if isinstance(dt, StructType): @@ -818,11 +997,22 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... + >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... 
""" # all objects are nullable if obj is None: return + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.serialize(obj), dataType.sqlType()) + return + _type = type(dataType) assert _type in _acceptable_types, "unkown datatype: %s" % dataType @@ -897,6 +1087,8 @@ def _has_struct_or_date(dt): return _has_struct_or_date(dt.valueType) elif isinstance(dt, DateType): return True + elif isinstance(dt, UserDefinedType): + return True return False @@ -967,6 +1159,9 @@ def Dict(d): elif isinstance(dataType, DateType): return datetime.date + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + elif not isinstance(dataType, StructType): raise Exception("unexpected data type: %s" % dataType) @@ -1244,6 +1439,10 @@ def applySchema(self, rdd, schema): for row in rows: _verify_type(row, schema) + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) @@ -1877,6 +2076,7 @@ def _test(): # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import Row, SQLContext + from pyspark.tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: @@ -1888,6 +2088,8 @@ def _test(): Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) + globs['ExamplePoint'] = ExamplePoint + globs['ExamplePointUDT'] = ExamplePointUDT jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 68fd75687621..e947b0946810 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -49,7 +49,8 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType from pyspark import shuffle _have_scipy = False @@ -694,8 +695,65 @@ def heavy_foo(x): self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. 
+ """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + class SQLTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) + def setUp(self): self.sqlCtx = SQLContext(self.sc) @@ -824,6 +882,39 @@ def test_convert_row_to_dict(self): row = self.sqlCtx.sql("select l[0].a AS la from test").first() self.assertEqual(1, row.asDict()["la"]) + def test_infer_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd = self.sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + srdd.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + srdd = self.sqlCtx.applySchema(rdd, schema) + point = srdd.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + srdd0.saveAsParquetFile(output_dir) + srdd1 = self.sqlCtx.parquetFile(output_dir) + point = srdd1.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + class InputFormatTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index e1b5992a36e5..5dd19dd12d8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -71,6 +71,8 @@ object DataType { case JSortedObject( ("class", JString(udtClass)), + ("pyClass", _), + ("sqlType", _), ("type", JString("udt"))) => Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] } @@ -593,6 +595,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** Underlying storage type for this UDT */ def sqlType: DataType + /** Paired Python UDT class, if exists. 
*/ + def pyUDT: String = null + /** * Convert the user type to a SQL datum * @@ -606,7 +611,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def jsonValue: JValue = { ("type" -> "udt") ~ - ("class" -> this.getClass.getName) + ("class" -> this.getClass.getName) ~ + ("pyClass" -> pyUDT) ~ + ("sqlType" -> sqlType.jsonValue) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9e61d18f7e92..84eaf401f240 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation @@ -483,6 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true + case udt: UserDefinedType[_] => needsConversion(udt.sqlType) case other => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 997669051ed0..a83cf5d441d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -135,6 +135,8 @@ object EvaluatePython { case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type }.asJava + case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal // Pyrolite can handle Timestamp @@ -177,6 +179,9 @@ object EvaluatePython { case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) + case (_, udt: UserDefinedType[_]) => + fromJava(obj, udt.sqlType) + case (c: Int, ByteType) => c.toByte case (c: Long, ByteType) => c.toByte case (c: Int, ShortType) => c.toShort diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala new file mode 100644 index 000000000000..b9569e96c031 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.types._ + +/** + * An example class to demonstrate UDT in Scala, Java, and Python. + * @param x x coordinate + * @param y y coordinate + */ +@SQLUserDefinedType(udt = classOf[ExamplePointUDT]) +private[sql] class ExamplePoint(val x: Double, val y: Double) + +/** + * User-defined type for [[ExamplePoint]]. + */ +private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { + + override def sqlType: DataType = ArrayType(DoubleType, false) + + override def pyUDT: String = "pyspark.tests.ExamplePointUDT" + + override def serialize(obj: Any): Seq[Double] = { + obj match { + case p: ExamplePoint => + Seq(p.x, p.y) + } + } + + override def deserialize(datum: Any): ExamplePoint = { + datum match { + case values: Seq[_] => + val xy = values.asInstanceOf[Seq[Double]] + assert(xy.length == 2) + new ExamplePoint(xy(0), xy(1)) + case values: util.ArrayList[_] => + val xy = values.asInstanceOf[util.ArrayList[Double]].asScala + new ExamplePoint(xy(0), xy(1)) + } + } + + override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 1bc15146f0fe..3fa4a7c6481d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.types.UserDefinedType - protected[sql] object DataTypeConversions { /** From 8395e8fbdf23bef286ec68a4bbadcc448b504c2c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 3 Nov 2014 22:29:48 -0800 Subject: [PATCH 012/652] [SPARK-3573][MLLIB] Make MLlib's Vector compatible with SQL's SchemaRDD Register MLlib's Vector as a SQL user-defined type (UDT) in both Scala and Python. With this PR, we can easily map a RDD[LabeledPoint] to a SchemaRDD, and then select columns or save to a Parquet file. Examples in Scala/Python are attached. The Scala code was copied from jkbradley. ~~This PR contains the changes from #3068 . 
I will rebase after #3068 is merged.~~ marmbrus jkbradley Author: Xiangrui Meng Closes #3070 from mengxr/SPARK-3573 and squashes the following commits: 3a0b6e5 [Xiangrui Meng] organize imports 236f0a0 [Xiangrui Meng] register vector as UDT and provide dataset examples (cherry picked from commit 1a9c6cddadebdc53d083ac3e0da276ce979b5d1f) Signed-off-by: Xiangrui Meng --- dev/run-tests | 2 +- .../src/main/python/mllib/dataset_example.py | 62 +++++++++ .../spark/examples/mllib/DatasetExample.scala | 121 ++++++++++++++++++ mllib/pom.xml | 5 + .../apache/spark/mllib/linalg/Vectors.scala | 69 +++++++++- .../spark/mllib/linalg/VectorsSuite.scala | 11 ++ python/pyspark/mllib/linalg.py | 50 ++++++++ python/pyspark/mllib/tests.py | 39 +++++- 8 files changed, 353 insertions(+), 6 deletions(-) create mode 100644 examples/src/main/python/mllib/dataset_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala diff --git a/dev/run-tests b/dev/run-tests index 0e9eefa76a18..de607e434445 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -180,7 +180,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string #+ will be interpreted as a single test, which doesn't work. - SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test") + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test") else SBT_MAVEN_TEST_ARGS=("test") fi diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py new file mode 100644 index 000000000000..540dae785f6e --- /dev/null +++ b/examples/src/main/python/mllib/dataset_example.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example of how to use SchemaRDD as a dataset for ML. 
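In essence, because Vector is now registered as a UDT, an RDD of LabeledPoint can be converted into a SchemaRDD and written to Parquet directly; a minimal sketch (assuming an existing SparkContext `sc`, the bundled sample_libsvm_data.txt file, and a hypothetical output path) follows, condensed from the dataset_example.py script below.

~~~python
# Condensed from dataset_example.py: infer a schema over LabeledPoint records
# (the features column is typed as VectorUDT) and round-trip it through Parquet.
from pyspark.sql import SQLContext
from pyspark.mllib.util import MLUtils

sqlCtx = SQLContext(sc)
points = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
dataset = sqlCtx.inferSchema(points)
dataset.saveAsParquetFile("/tmp/dataset")            # hypothetical output path
print sqlCtx.parquetFile("/tmp/dataset").schema().json()
~~~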
Run with:: + bin/spark-submit examples/src/main/python/mllib/dataset_example.py +""" + +import os +import sys +import tempfile +import shutil + +from pyspark import SparkContext +from pyspark.sql import SQLContext +from pyspark.mllib.util import MLUtils +from pyspark.mllib.stat import Statistics + + +def summarize(dataset): + print "schema: %s" % dataset.schema().json() + labels = dataset.map(lambda r: r.label) + print "label average: %f" % labels.mean() + features = dataset.map(lambda r: r.features) + summary = Statistics.colStats(features) + print "features average: %r" % summary.mean() + +if __name__ == "__main__": + if len(sys.argv) > 2: + print >> sys.stderr, "Usage: dataset_example.py " + exit(-1) + sc = SparkContext(appName="DatasetExample") + sqlCtx = SQLContext(sc) + if len(sys.argv) == 2: + input = sys.argv[1] + else: + input = "data/mllib/sample_libsvm_data.txt" + points = MLUtils.loadLibSVMFile(sc, input) + dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + summarize(dataset0) + tempdir = tempfile.NamedTemporaryFile(delete=False).name + os.unlink(tempdir) + print "Save dataset as a Parquet file to %s." % tempdir + dataset0.saveAsParquetFile(tempdir) + print "Load it back and summarize it again." + dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + summarize(dataset1) + shutil.rmtree(tempdir) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala new file mode 100644 index 000000000000..f8d83f4ec732 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import java.io.File + +import com.google.common.io.Files +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} + +/** + * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
+ */ +object DatasetExample { + + case class Params( + input: String = "data/mllib/sample_libsvm_data.txt", + dataFormat: String = "libsvm") extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DatasetExample") { + head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + opt[String]("input") + .text(s"input path to dataset") + .action((x, c) => c.copy(input = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(input = x)) + checkConfig { params => + success + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ // for implicit conversions + + // Load input data + val origData: RDD[LabeledPoint] = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) + } + println(s"Loaded ${origData.count()} instances from file: ${params.input}") + + // Convert input data to SchemaRDD explicitly. + val schemaRDD: SchemaRDD = origData + println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") + println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + + // Select columns, using implicit conversion to SchemaRDD. + val labelsSchemaRDD: SchemaRDD = origData.select('label) + val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + val numLabels = labels.count() + val meanLabel = labels.fold(0.0)(_ + _) / numLabels + println(s"Selected label column with average value $meanLabel") + + val featuresSchemaRDD: SchemaRDD = origData.select('features) + val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val outputDir = new File(tmpDir, "dataset").toString + println(s"Saving to $outputDir as Parquet file.") + schemaRDD.saveAsParquetFile(outputDir) + + println(s"Loading Parquet file with UDT from $outputDir.") + val newDataset = sqlContext.parquetFile(outputDir) + + println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") + val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + + sc.stop() + } + +} diff --git a/mllib/pom.xml b/mllib/pom.xml index fb7239e779aa..87a7ddaba97f 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -45,6 +45,11 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + org.eclipse.jetty jetty-server diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 6af225b7f49f..ac217edc619a 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,22 +17,26 @@ package org.apache.spark.mllib.linalg -import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import java.util +import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import scala.annotation.varargs import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException +import org.apache.spark.mllib.util.NumericParser +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} +import org.apache.spark.sql.catalyst.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. * * Note: Users should not implement this interface. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) sealed trait Vector extends Serializable { /** @@ -74,6 +78,65 @@ sealed trait Vector extends Serializable { } } +/** + * User-defined type for [[Vector]] which allows easy interaction with SQL + * via [[org.apache.spark.sql.SchemaRDD]]. + */ +private[spark] class VectorUDT extends UserDefinedType[Vector] { + + override def sqlType: StructType = { + // type: 0 = sparse, 1 = dense + // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse + // vectors. The "values" field is nullable because we might want to add binary vectors later, + // which uses "size" and "indices", but not "values". + StructType(Seq( + StructField("type", ByteType, nullable = false), + StructField("size", IntegerType, nullable = true), + StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) + } + + override def serialize(obj: Any): Row = { + val row = new GenericMutableRow(4) + obj match { + case sv: SparseVector => + row.setByte(0, 0) + row.setInt(1, sv.size) + row.update(2, sv.indices.toSeq) + row.update(3, sv.values.toSeq) + case dv: DenseVector => + row.setByte(0, 1) + row.setNullAt(1) + row.setNullAt(2) + row.update(3, dv.values.toSeq) + } + row + } + + override def deserialize(datum: Any): Vector = { + datum match { + case row: Row => + require(row.length == 4, + s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + val tpe = row.getByte(0) + tpe match { + case 0 => + val size = row.getInt(1) + val indices = row.getAs[Iterable[Int]](2).toArray + val values = row.getAs[Iterable[Double]](3).toArray + new SparseVector(size, indices, values) + case 1 => + val values = row.getAs[Iterable[Double]](3).toArray + new DenseVector(values) + } + } + } + + override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT" + + override def userClass: Class[Vector] = classOf[Vector] +} + /** * Factory methods for [[org.apache.spark.mllib.linalg.Vector]]. * We don't use the name `Vector` because Scala imports @@ -191,6 +254,7 @@ object Vectors { /** * A dense vector represented by a value array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector { * @param indices index array, assume to be strictly increasing. 
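Illustrative note (not part of the patch): the UDT above stores every vector as a four-field row — a type byte (0 = sparse, 1 = dense), an optional size, optional indices, and the values. A minimal sketch of the round trip through the matching PySpark class added later in this patch, assuming a PySpark build that includes these changes:

```
from pyspark.mllib.linalg import DenseVector, SparseVector, VectorUDT

udt = VectorUDT()
sv = SparseVector(2, [1], [2.0])
dv = DenseVector([1.0, 2.0])

# Sparse vectors populate all four fields; dense vectors only "values".
print(udt.serialize(sv))   # (0, 2, [1], [2.0])
print(udt.serialize(dv))   # (1, None, None, [1.0, 2.0])

# Serialization should round-trip to an equal vector, as the new
# VectorsSuite / VectorUDTTests in this patch assert.
assert udt.deserialize(udt.serialize(sv)) == sv
assert udt.deserialize(udt.serialize(dv)) == dv
```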
* @param values value array, must have the same length as the index array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class SparseVector( override val size: Int, val indices: Array[Int], diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index cd651fe2d2dd..93a84fe07b32 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite { throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") } } + + test("VectorUDT") { + val dv0 = Vectors.dense(Array.empty[Double]) + val dv1 = Vectors.dense(1.0, 2.0) + val sv0 = Vectors.sparse(2, Array.empty, Array.empty) + val sv1 = Vectors.sparse(2, Array(1), Array(2.0)) + val udt = new VectorUDT() + for (v <- Seq(dv0, dv1, sv0, sv1)) { + assert(v === udt.deserialize(udt.serialize(v))) + } + } } diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index d0a0e102a1a0..c0c3dff31e7f 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,6 +29,9 @@ import numpy as np +from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ + IntegerType, ByteType, Row + __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] @@ -106,7 +109,54 @@ def _format_float(f, digits=4): return s +class VectorUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Vector. + """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("size", IntegerType(), True), + StructField("indices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.VectorUDT" + + def serialize(self, obj): + if isinstance(obj, SparseVector): + indices = [int(i) for i in obj.indices] + values = [float(v) for v in obj.values] + return (0, obj.size, indices, values) + elif isinstance(obj, DenseVector): + values = [float(v) for v in obj] + return (1, None, None, values) + else: + raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) + + def deserialize(self, datum): + assert len(datum) == 4, \ + "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseVector(datum[1], datum[2], datum[3]) + elif tpe == 1: + return DenseVector(datum[3]) + else: + raise ValueError("do not recognize type %r" % tpe) + + class Vector(object): + + __UDT__ = VectorUDT() + """ Abstract class for DenseVector and SparseVector """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index d6fb87b378b4..9fa4d6f6a2f5 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -33,14 +33,14 @@ else: import unittest -from pyspark.serializers import PickleSerializer -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.serializers import PickleSerializer +from pyspark.sql import SQLContext from pyspark.tests 
import ReusedPySparkTestCase as PySparkTestCase - _have_scipy = False try: import scipy.sparse @@ -221,6 +221,39 @@ def test_col_with_different_rdds(self): self.assertEqual(10, summary.count()) +class VectorUDTTests(PySparkTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + sqlCtx = SQLContext(self.sc) + rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) + srdd = sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = srdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): From 786e75b33f0bc1445bfc289fe4b62407cb79026e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 3 Nov 2014 23:56:14 -0800 Subject: [PATCH 013/652] [SPARK-3886] [PySpark] simplify serializer, use AutoBatchedSerializer by default. This PR simplify serializer, always use batched serializer (AutoBatchedSerializer as default), even batch size is 1. Author: Davies Liu This patch had conflicts when merged, resolved by Committer: Josh Rosen Closes #2920 from davies/fix_autobatch and squashes the following commits: e544ef9 [Davies Liu] revert unrelated change 6880b14 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch 1d557fc [Davies Liu] fix tests 8180907 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch 76abdce [Davies Liu] clean up 53fa60b [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch d7ac751 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch 2cc2497 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch b4292ce [Davies Liu] fix bug in master d79744c [Davies Liu] recover hive tests be37ece [Davies Liu] refactor eb3938d [Davies Liu] refactor serializer in scala 8d77ef2 [Davies Liu] simplify serializer, use AutoBatchedSerializer by default. 
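Illustrative note (not part of the patch): the auto-batching that this commit makes the default grows or shrinks the number of objects per pickled chunk based on the size of the previous chunk. A standalone, plain-Python sketch of that policy (the Scala `AutoBatchedPickler` below targets roughly 1MB–10MB chunks; the Python `AutoBatchedSerializer` uses a smaller `bestSize`):

```
import pickle

def auto_batched_chunks(iterator, low=1 << 20, high=10 << 20):
    # Start with one object per chunk and adapt: double the batch while
    # chunks come out smaller than `low`, halve it when they exceed `high`.
    batch, buf, it = 1, [], iter(iterator)
    while True:
        try:
            while len(buf) < batch:
                buf.append(next(it))
        except StopIteration:
            if not buf:
                return
        chunk = pickle.dumps(buf)
        if len(chunk) < low:
            batch *= 2
        elif len(chunk) > high and batch > 1:
            batch //= 2
        buf = []
        yield chunk

chunks = list(auto_batched_chunks(range(200000)))
print("%d chunks, first sizes: %r" % (len(chunks), [len(c) for c in chunks[:4]]))
```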
(cherry picked from commit e4f42631a68b473ce706429915f3f08042af2119) Signed-off-by: Josh Rosen --- .../spark/api/python/PythonHadoopUtil.scala | 6 +- .../apache/spark/api/python/PythonRDD.scala | 110 +--------------- .../apache/spark/api/python/SerDeUtil.scala | 121 +++++++++++++----- .../WriteInputFormatTestDataGenerator.scala | 10 +- .../mllib/api/python/PythonMLLibAPI.scala | 2 +- python/pyspark/context.py | 58 +++------ python/pyspark/mllib/common.py | 2 +- python/pyspark/mllib/recommendation.py | 2 +- python/pyspark/rdd.py | 91 ++++++------- python/pyspark/serializers.py | 36 ++---- python/pyspark/shuffle.py | 7 +- python/pyspark/sql.py | 18 +-- python/pyspark/tests.py | 66 ++-------- .../org/apache/spark/sql/SchemaRDD.scala | 10 +- 14 files changed, 201 insertions(+), 338 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 49dc95f349ea..5ba66178e2b7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -61,8 +61,7 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. */ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]], - batchSize: Int) extends Converter[Any, Any] { + conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or @@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter( map.put(convertWritable(k), convertWritable(v)) } map - case w: Writable => - if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w + case w: Writable => WritableUtils.clone(w, conf.value.value) case other => other } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 61b125ef7c6c..e94ccdcd47bb 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -22,12 +22,10 @@ import java.net._ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials import com.google.common.base.Charsets.UTF_8 -import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec @@ -442,7 +440,7 @@ private[spark] object PythonRDD extends Logging { val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -468,7 +466,7 @@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) 
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -494,7 +492,7 @@ private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -537,7 +535,7 @@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -563,7 +561,7 @@ private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -746,104 +744,6 @@ private[spark] object PythonRDD extends Logging { converted.saveAsHadoopDataset(new JobConf(conf)) } } - - - /** - * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - */ - @deprecated("PySpark does not use it anymore", "1.1") - def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - SerDeUtil.initialize() - iter.flatMap { row => - unpickle.loads(row) match { - // in case of objects are pickled in batch mode - case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // not in batch mode - case obj: JMap[String @unchecked, _] => Seq(obj.toMap) - } - } - } - } - - /** - * Convert an RDD of serialized Python tuple to Array (no recursive conversions). - * It is only used by pyspark.sql. - */ - def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = { - - def toArray(obj: Any): Array[_] = { - obj match { - case objs: JArrayList[_] => - objs.toArray - case obj if obj.getClass.isArray => - obj.asInstanceOf[Array[_]].toArray - } - } - - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].map(toArray) - } else { - Seq(toArray(obj)) - } - } - }.toJavaRDD() - } - - private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { - private val pickle = new Pickler() - private var batch = 1 - private val buffer = new mutable.ArrayBuffer[Any] - - override def hasNext(): Boolean = iter.hasNext - - override def next(): Array[Byte] = { - while (iter.hasNext && buffer.length < batch) { - buffer += iter.next() - } - val bytes = pickle.dumps(buffer.toArray) - val size = bytes.length - // let 1M < size < 10M - if (size < 1024 * 1024) { - batch *= 2 - } else if (size > 1024 * 1024 * 10 && batch > 1) { - batch /= 2 - } - buffer.clear() - bytes - } - } - - /** - * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by - * PySpark. 
- */ - def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { - jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } - } - - /** - * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. - */ - def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { - pyRDD.rdd.mapPartitions { iter => - SerDeUtil.initialize() - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala - } else { - Seq(obj) - } - } - }.toJavaRDD() - } } private diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index ebdc3533e099..a4153aaa926f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -18,8 +18,13 @@ package org.apache.spark.api.python import java.nio.ByteOrder +import java.util.{ArrayList => JArrayList} + +import org.apache.spark.api.java.JavaRDD import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Failure import scala.util.Try @@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging { } initialize() + + /** + * Convert an RDD of Java objects to Array (no recursive conversions). + * It is only used by pyspark.sql. + */ + def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = { + jrdd.rdd.map { + case objs: JArrayList[_] => + objs.toArray + case obj if obj.getClass.isArray => + obj.asInstanceOf[Array[_]].toArray + }.toJavaRDD() + } + + /** + * Choose batch size based on size of objects + */ + private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { + private val pickle = new Pickler() + private var batch = 1 + private val buffer = new mutable.ArrayBuffer[Any] + + override def hasNext: Boolean = iter.hasNext + + override def next(): Array[Byte] = { + while (iter.hasNext && buffer.length < batch) { + buffer += iter.next() + } + val bytes = pickle.dumps(buffer.toArray) + val size = bytes.length + // let 1M < size < 10M + if (size < 1024 * 1024) { + batch *= 2 + } else if (size > 1024 * 1024 * 10 && batch > 1) { + batch /= 2 + } + buffer.clear() + bytes + } + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } + } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. 
+ */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].asScala + } else { + Seq(obj) + } + } + }.toJavaRDD() + } + private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { val pickle = new Pickler val kt = Try { @@ -128,17 +200,18 @@ private[spark] object SerDeUtil extends Logging { */ def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { val (keyFailed, valueFailed) = checkPickle(rdd.first()) + rdd.mapPartitions { iter => - val pickle = new Pickler val cleaned = iter.map { case (k, v) => val key = if (keyFailed) k.toString else k val value = if (valueFailed) v.toString else v Array[Any](key, value) } - if (batchSize > 1) { - cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + if (batchSize == 0) { + new AutoBatchedPickler(cleaned) } else { - cleaned.map(pickle.dumps(_)) + val pickle = new Pickler + cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) } } } @@ -146,36 +219,22 @@ private[spark] object SerDeUtil extends Logging { /** * Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)]. */ - def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = { + def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = { def isPair(obj: Any): Boolean = { - Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) && + Option(obj.getClass.getComponentType).exists(!_.isPrimitive) && obj.asInstanceOf[Array[_]].length == 2 } - pyRDD.mapPartitions { iter => - initialize() - val unpickle = new Unpickler - val unpickled = - if (batchSerialized) { - iter.flatMap { batch => - unpickle.loads(batch) match { - case objs: java.util.List[_] => collectionAsScalaIterable(objs) - case other => throw new SparkException( - s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD") - } - } - } else { - iter.map(unpickle.loads(_)) - } - unpickled.map { - case obj if isPair(obj) => - // we only accept (K, V) - val arr = obj.asInstanceOf[Array[_]] - (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) - case other => throw new SparkException( - s"RDD element of type ${other.getClass.getName} cannot be used") - } + + val rdd = pythonToJava(pyRDD, batched).rdd + rdd.first match { + case obj if isPair(obj) => + // we only accept (K, V) + case other => throw new SparkException( + s"RDD element of type ${other.getClass.getName} cannot be used") + } + rdd.map { obj => + val arr = obj.asInstanceOf[Array[_]] + (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) } } - } - diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index e9ca9166eb4d..c0cbd28a845b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -176,11 +176,11 @@ object WriteInputFormatTestDataGenerator { // Create test data for arbitrary custom writable TestWritable val testClass = Seq( - ("1", TestWritable("test1", 123, 54.0)), - ("2", TestWritable("test2", 456, 8762.3)), - ("1", TestWritable("test3", 123, 423.1)), - ("3", TestWritable("test56", 456, 423.5)), - ("2", 
TestWritable("test2", 123, 5435.2)) + ("1", TestWritable("test1", 1, 1.0)), + ("2", TestWritable("test2", 2, 2.3)), + ("3", TestWritable("test3", 3, 3.1)), + ("5", TestWritable("test56", 5, 5.5)), + ("4", TestWritable("test4", 4, 4.2)) ) val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) } rdd.saveAsNewAPIHadoopFile(classPath, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index acdc67ddc660..65b98a8ceea5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -736,7 +736,7 @@ private[spark] object SerDe extends Serializable { def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { jRDD.rdd.mapPartitions { iter => initialize() // let it called in executor - new PythonRDD.AutoBatchedPickler(iter) + new SerDeUtil.AutoBatchedPickler(iter) } } diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 5f8dcedb1eea..a0e4821728c8 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -63,7 +63,6 @@ class SparkContext(object): _active_spark_context = None _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH - _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, @@ -115,9 +114,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size self._unbatched_serializer = serializer - if batchSize == 1: - self.serializer = self._unbatched_serializer - elif batchSize == 0: + if batchSize == 0: self.serializer = AutoBatchedSerializer(self._unbatched_serializer) else: self.serializer = BatchedSerializer(self._unbatched_serializer, @@ -305,12 +302,8 @@ def parallelize(self, c, numSlices=None): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self._batchSize) - if batchSize > 1: - serializer = BatchedSerializer(self._unbatched_serializer, - batchSize) - else: - serializer = self._unbatched_serializer + batchSize = max(1, min(len(c) // numSlices, self._batchSize)) + serializer = BatchedSerializer(self._unbatched_serializer, batchSize) serializer.dump_stream(c, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile @@ -328,8 +321,7 @@ def pickleFile(self, name, minPartitions=None): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ minPartitions = minPartitions or self.defaultMinPartitions - return RDD(self._jsc.objectFile(name, minPartitions), self, - BatchedSerializer(PickleSerializer())) + return RDD(self._jsc.objectFile(name, minPartitions), self) def textFile(self, name, minPartitions=None, use_unicode=True): """ @@ -405,7 +397,7 @@ def _dictToJavaMap(self, d): return jm def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, - valueConverter=None, minSplits=None, batchSize=None): + valueConverter=None, minSplits=None, batchSize=0): """ Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system 
URI. @@ -427,17 +419,15 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, :param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ minSplits = minSplits or min(self.defaultParallelism, 2) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, keyConverter, valueConverter, minSplits, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -458,18 +448,16 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -487,18 +475,16 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. 
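Illustrative note (not part of the patch): with these signatures, `batchSize=0` — now the default — asks PySpark to pick the batch size automatically, while a positive value still forces fixed-size batches. A hedged usage sketch, assuming an active SparkContext `sc` and a hypothetical input path:

```
# "hdfs:///tmp/ints-and-text" is a placeholder path, not taken from the patch.
auto = sc.sequenceFile("hdfs:///tmp/ints-and-text")              # batchSize=0 (automatic)
fixed = sc.sequenceFile("hdfs:///tmp/ints-and-text", batchSize=5)

new_api = sc.newAPIHadoopFile(
    "hdfs:///tmp/ints-and-text",
    "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
    "org.apache.hadoop.io.IntWritable",
    "org.apache.hadoop.io.Text")   # also defaults to automatic batching
```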
@@ -519,18 +505,16 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -548,15 +532,13 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) @@ -836,7 +818,7 @@ def _test(): import doctest import tempfile globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 76864d816358..dbe5f698b734 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -96,7 +96,7 @@ def _java2py(sc, r): if clsName == 'JavaRDD': jrdd = sc._jvm.SerDe.javaToPython(r) - return RDD(jrdd, sc, AutoBatchedSerializer(PickleSerializer())) + return RDD(jrdd, sc) elif isinstance(r, (JavaArray, JavaList)) or clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 6b32af07c9be..e8b998414d31 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -117,7 +117,7 @@ def _test(): import doctest import pyspark.mllib.recommendation globs = pyspark.mllib.recommendation.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4f025b9f1170..879655dc53f4 
100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -120,7 +120,7 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx, jrdd_deserializer): + def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSerializer())): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False @@ -129,12 +129,8 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self._id = jrdd.id() self._partitionFunc = None - def _toPickleSerialization(self): - if (self._jrdd_deserializer == PickleSerializer() or - self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): - return self - else: - return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) + def _pickled(self): + return self._reserialize(AutoBatchedSerializer(PickleSerializer())) def id(self): """ @@ -446,12 +442,11 @@ def intersection(self, other): def _reserialize(self, serializer=None): serializer = serializer or self.ctx.serializer - if self._jrdd_deserializer == serializer: - return self - else: - converted = self.map(lambda x: x, preservesPartitioning=True) - converted._jrdd_deserializer = serializer - return converted + if self._jrdd_deserializer != serializer: + if not isinstance(self, PipelinedRDD): + self = self.map(lambda x: x, preservesPartitioning=True) + self._jrdd_deserializer = serializer + return self def __add__(self, other): """ @@ -1120,9 +1115,8 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, True) def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, @@ -1147,9 +1141,8 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl :param conf: Hadoop job configuration, passed in as a dict (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) @@ -1166,9 +1159,8 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, False) def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, @@ -1195,9 +1187,8 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No :param compressionCodecClass: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - 
self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, @@ -1215,9 +1206,8 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): :param path: path to sequence file :param compressionCodecClass: (None by default) """ - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, batched, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, True, path, compressionCodecClass) def saveAsPickleFile(self, path, batchSize=10): @@ -1232,8 +1222,11 @@ def saveAsPickleFile(self, path, batchSize=10): >>> sorted(sc.pickleFile(tmpFile.name, 5).collect()) [1, 2, 'rdd', 'spark'] """ - self._reserialize(BatchedSerializer(PickleSerializer(), - batchSize))._jrdd.saveAsObjectFile(path) + if batchSize == 0: + ser = AutoBatchedSerializer(PickleSerializer()) + else: + ser = BatchedSerializer(PickleSerializer(), batchSize) + self._reserialize(ser)._jrdd.saveAsObjectFile(path) def saveAsTextFile(self, path): """ @@ -1774,13 +1767,10 @@ def zip(self, other): >>> x.zip(y).collect() [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)] """ - if self.getNumPartitions() != other.getNumPartitions(): - raise ValueError("Can only zip with RDD which has the same number of partitions") - def get_batch_size(ser): if isinstance(ser, BatchedSerializer): return ser.batchSize - return 0 + return 1 def batch_as(rdd, batchSize): ser = rdd._jrdd_deserializer @@ -1790,12 +1780,16 @@ def batch_as(rdd, batchSize): my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) - if my_batch != other_batch: - # use the greatest batchSize to batch the other one. - if my_batch > other_batch: - other = batch_as(other, my_batch) - else: - self = batch_as(self, other_batch) + # use the smallest batchSize for both of them + batchSize = min(my_batch, other_batch) + if batchSize <= 0: + # auto batched or unlimited + batchSize = 100 + other = batch_as(other, batchSize) + self = batch_as(self, batchSize) + + if self.getNumPartitions() != other.getNumPartitions(): + raise ValueError("Can only zip with RDD which has the same number of partitions") # There will be an Exception in JVM if there are different number # of items in each partitions. @@ -1934,25 +1928,14 @@ def lookup(self, key): return values.collect() - def _is_pickled(self): - """ Return this RDD is serialized by Pickle or not. """ - der = self._jrdd_deserializer - if isinstance(der, PickleSerializer): - return True - if isinstance(der, BatchedSerializer) and isinstance(der.serializer, PickleSerializer): - return True - return False - def _to_java_object_rdd(self): """ Return an JavaRDD of Object by unpickling It will convert each Python object into Java object by Pyrolite, whenever the RDD is serialized in batch or not. 
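Illustrative note (not part of the patch): the same convention applies to writes — `batchSize=0` selects `AutoBatchedSerializer` — and `zip()` now re-batches both sides to a common size instead of failing on mismatched serializers. A small sketch, assuming an active SparkContext `sc`:

```
from tempfile import NamedTemporaryFile

tmp = NamedTemporaryFile(delete=True)
tmp.close()                                   # keep only an unused temp path
rdd = sc.parallelize(range(1000), 4)
rdd.saveAsPickleFile(tmp.name, batchSize=0)   # 0 => auto-batched pickling
print(sorted(sc.pickleFile(tmp.name, 3).collect())[:5])    # [0, 1, 2, 3, 4]

x = sc.parallelize(range(0, 5))
y = sc.parallelize(range(1000, 1005))
print(x.zip(y).collect())   # [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
```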
""" - rdd = self._reserialize(AutoBatchedSerializer(PickleSerializer())) \ - if not self._is_pickled() else self - is_batch = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - return self.ctx._jvm.PythonRDD.pythonToJava(rdd._jrdd, is_batch) + rdd = self._pickled() + return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True) def countApprox(self, timeout, confidence=0.95): """ @@ -2132,7 +2115,7 @@ def _test(): globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 904bd9f2652d..d597cbf94e1b 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,9 +33,8 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -By default, PySpark serialize objects in batches; the batch size can be -controlled through SparkContext's C{batchSize} parameter -(the default size is 1024 objects): +PySpark serialize objects in batches; By default, the batch size is chosen based +on the size of objects, also configurable by SparkContext's C{batchSize} parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -48,16 +47,6 @@ >>> rdd._jrdd.count() 8L >>> sc.stop() - -A batch size of -1 uses an unlimited batch size, and a size of 1 disables -batching: - ->>> sc = SparkContext('local', 'test', batchSize=1) ->>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) ->>> rdd.glom().collect() -[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> rdd._jrdd.count() -16L """ import cPickle @@ -73,7 +62,7 @@ from pyspark import cloudpickle -__all__ = ["PickleSerializer", "MarshalSerializer"] +__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] class SpecialLengths(object): @@ -113,7 +102,7 @@ def __ne__(self, other): return not self.__eq__(other) def __repr__(self): - return "<%s object>" % self.__class__.__name__ + return "%s()" % self.__class__.__name__ def __hash__(self): return hash(str(self)) @@ -181,6 +170,7 @@ class BatchedSerializer(Serializer): """ UNLIMITED_BATCH_SIZE = -1 + UNKNOWN_BATCH_SIZE = 0 def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): self.serializer = serializer @@ -213,10 +203,10 @@ def _load_stream_without_unbatching(self, stream): def __eq__(self, other): return (isinstance(other, BatchedSerializer) and - other.serializer == self.serializer) + other.serializer == self.serializer and other.batchSize == self.batchSize) def __repr__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) class AutoBatchedSerializer(BatchedSerializer): @@ -225,7 +215,7 @@ class AutoBatchedSerializer(BatchedSerializer): """ def __init__(self, serializer, bestSize=1 << 16): - BatchedSerializer.__init__(self, serializer, -1) + BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE) self.bestSize = bestSize def dump_stream(self, iterator, stream): @@ -248,10 +238,10 @@ def dump_stream(self, iterator, stream): def __eq__(self, other): return (isinstance(other, AutoBatchedSerializer) and - other.serializer == self.serializer) + other.serializer == self.serializer and other.bestSize == self.bestSize) def 
__str__(self): - return "AutoBatchedSerializer<%s>" % str(self.serializer) + return "AutoBatchedSerializer(%s)" % str(self.serializer) class CartesianDeserializer(FramedSerializer): @@ -284,7 +274,7 @@ def __eq__(self, other): self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __repr__(self): - return "CartesianDeserializer<%s, %s>" % \ + return "CartesianDeserializer(%s, %s)" % \ (str(self.key_ser), str(self.val_ser)) @@ -311,7 +301,7 @@ def __eq__(self, other): self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __repr__(self): - return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser)) + return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) class NoOpSerializer(FramedSerializer): @@ -430,7 +420,7 @@ def loads(self, obj): class AutoSerializer(FramedSerializer): """ - Choose marshal or cPickle as serialization protocol autumatically + Choose marshal or cPickle as serialization protocol automatically """ def __init__(self): diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index d57a802e4734..5931e923c2e3 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -25,7 +25,7 @@ import random import pyspark.heapq3 as heapq -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer try: import psutil @@ -213,8 +213,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, Merger.__init__(self, aggregator) self.memory_limit = memory_limit # default serializer is only used for tests - self.serializer = serializer or \ - BatchedSerializer(PickleSerializer(), 1024) + self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -470,7 +469,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024) + self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) def _get_path(self, n): """ Choose one directory for spill by number n """ diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index d16c18bc79fe..e5d62a466cab 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -44,7 +44,8 @@ from py4j.java_collections import ListConverter, MapConverter from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer +from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ + CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -1233,7 +1234,6 @@ def __init__(self, sparkContext, sqlContext=None): self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray self._scala_SQLContext = sqlContext @property @@ -1263,8 +1263,8 @@ def registerFunction(self, name, f, returnType=StringType()): """ func = lambda _, it: imap(lambda x: f(*x), it) command = (func, None, - BatchedSerializer(PickleSerializer(), 1024), - BatchedSerializer(PickleSerializer(), 1024)) + AutoBatchedSerializer(PickleSerializer()), + AutoBatchedSerializer(PickleSerializer())) ser = CloudPickleSerializer() pickled_command = ser.dumps(command) if 
len(pickled_command) > (1 << 20): # 1M @@ -1443,8 +1443,7 @@ def applySchema(self, rdd, schema): converter = _python_to_sql_converter(schema) rdd = rdd.map(converter) - batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - jrdd = self._pythonToJava(rdd._jrdd, batched) + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1841,7 +1840,7 @@ def __init__(self, jschema_rdd, sql_ctx): self.is_checkpointed = False self.ctx = self.sql_ctx._sc # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = BatchedSerializer(PickleSerializer()) + self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) @property def _jrdd(self): @@ -2071,16 +2070,13 @@ def subtract(self, other, numPartitions=None): def _test(): import doctest - from array import array from pyspark.context import SparkContext # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import Row, SQLContext from pyspark.tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.__dict__.copy() - # The small batch size here ensures that we see multiple batches, - # even in these small test examples: - sc = SparkContext('local[4]', 'PythonTest', batchSize=2) + sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize( diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index e947b0946810..7e61b017efa7 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -242,7 +242,7 @@ class PySparkTestCase(unittest.TestCase): def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name, batchSize=2) + self.sc = SparkContext('local[4]', class_name) def tearDown(self): self.sc.stop() @@ -253,7 +253,7 @@ class ReusedPySparkTestCase(unittest.TestCase): @classmethod def setUpClass(cls): - cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2) + cls.sc = SparkContext('local[4]', cls.__name__) @classmethod def tearDownClass(cls): @@ -671,7 +671,7 @@ def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf) + self.sc = SparkContext('local[4]', class_name, conf=conf) def test_profiler(self): @@ -1012,16 +1012,19 @@ def test_sequencefiles(self): clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable").collect()) - ec = (u'1', - {u'__class__': u'org.apache.spark.api.python.TestWritable', - u'double': 54.0, u'int': 123, u'str': u'test1'}) - self.assertEqual(clazz[0], ec) + cname = u'org.apache.spark.api.python.TestWritable' + ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), + (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), + (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), + (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), + (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] + self.assertEqual(clazz, ec) unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable", - batchSize=1).collect()) 
- self.assertEqual(unbatched_clazz[0], ec) + ).collect()) + self.assertEqual(unbatched_clazz, ec) def test_oldhadoop(self): basepath = self.tempdir.name @@ -1341,51 +1344,6 @@ def test_reserialization(self): result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) self.assertEqual(result5, data) - def test_unbatched_save_and_read(self): - basepath = self.tempdir.name - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei, len(ei)).saveAsSequenceFile( - basepath + "/unbatched/") - - unbatched_sequence = sorted(self.sc.sequenceFile( - basepath + "/unbatched/", - batchSize=1).collect()) - self.assertEqual(unbatched_sequence, ei) - - unbatched_hadoopFile = sorted(self.sc.hadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopFile, ei) - - unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopFile, ei) - - oldconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=oldconf, - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopRDD, ei) - - newconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=newconf, - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopRDD, ei) - def test_malformed_RDD(self): basepath = self.tempdir.name # non-batch-serialized RDD[[(K, V)]] should be rejected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 3ee2ea05cfa2..fbec2f9f4b2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.{List => JList} +import org.apache.spark.api.python.SerDeUtil + import scala.collection.JavaConversions._ import net.razorvine.pickle.Pickler @@ -385,12 +387,8 @@ class SchemaRDD( */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { val fieldTypes = schema.fields.map(_.dataType) - this.mapPartitions { iter => - val pickle = new Pickler - iter.map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)) - } + val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) } /** From 4b13bff939291caa1fb9b9a180db66b1d006153c Mon Sep 17 00:00:00 2001 From: Dariusz Kobylarz Date: Tue, 4 Nov 2014 09:53:43 -0800 Subject: [PATCH 014/652] fixed MLlib Naive-Bayes java example bug the filter tests Double objects by references whereas it should test their values Author: Dariusz Kobylarz Closes #3081 from dkobylarz/master and squashes the following commits: 5d43a39 [Dariusz Kobylarz] naive bayes example update a304b93 [Dariusz Kobylarz] fixed MLlib Naive-Bayes java example bug (cherry 
picked from commit bcecd73fdd4d2ec209259cfd57d3ad1d63f028f2) Signed-off-by: Xiangrui Meng --- docs/mllib-naive-bayes.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 7f9d4c656394..d5b044d94fdd 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -88,11 +88,11 @@ JavaPairRDD predictionAndLabel = return new Tuple2(model.predict(p.features()), p.label()); } }); -double accuracy = 1.0 * predictionAndLabel.filter(new Function, Boolean>() { +double accuracy = predictionAndLabel.filter(new Function, Boolean>() { @Override public Boolean call(Tuple2 pl) { - return pl._1() == pl._2(); + return pl._1().equals(pl._2()); } - }).count() / test.count(); + }).count() / (double) test.count(); {% endhighlight %} From b90451814b7ff7338881e60124d779e2fd89ac60 Mon Sep 17 00:00:00 2001 From: Niklas Wilcke <1wilcke@informatik.uni-hamburg.de> Date: Tue, 4 Nov 2014 09:57:03 -0800 Subject: [PATCH 015/652] [Spark-4060] [MLlib] exposing special rdd functions to the public Author: Niklas Wilcke <1wilcke@informatik.uni-hamburg.de> Closes #2907 from numbnut/master and squashes the following commits: 7f7c767 [Niklas Wilcke] [Spark-4060] [MLlib] exposing special rdd functions to the public, #2907 (cherry picked from commit f90ad5d426cb726079c490a9bb4b1100e2b4e602) Signed-off-by: Xiangrui Meng --- .../spark/mllib/evaluation/AreaUnderCurve.scala | 2 +- .../org/apache/spark/mllib/rdd/RDDFunctions.scala | 11 ++++++----- .../scala/org/apache/spark/mllib/rdd/SlidingRDD.scala | 5 +++-- .../apache/spark/mllib/rdd/RDDFunctionsSuite.scala | 6 +++--- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index 7858ec602483..078fbfbe4f0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -43,7 +43,7 @@ private[evaluation] object AreaUnderCurve { */ def of(curve: RDD[(Double, Double)]): Double = { curve.sliding(2).aggregate(0.0)( - seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points), combOp = _ + _ ) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index b5e403bc8c14..57c0768084e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.rdd import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD @@ -28,8 +29,8 @@ import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. */ -private[mllib] -class RDDFunctions[T: ClassTag](self: RDD[T]) { +@DeveloperApi +class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { /** * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding @@ -39,10 +40,10 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { * trigger a Spark job if the parent RDD has more than one partitions and the window size is * greater than 1. 
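Illustrative note (not part of the patch): `AreaUnderCurve.of` above sums a trapezoid over every `sliding(2)` window of the curve. A plain-Python sketch of that per-window computation, independent of Spark:

```
def trapezoid(p1, p2):
    # Area under the segment between two (x, y) points of the curve.
    (x1, y1), (x2, y2) = p1, p2
    return (y1 + y2) / 2.0 * (x2 - x1)

def area_under_curve(points):
    # Equivalent of aggregating trapezoid() over sliding(2) windows.
    return sum(trapezoid(a, b) for a, b in zip(points, points[1:]))

print(area_under_curve([(0.0, 0.0), (0.5, 0.8), (1.0, 1.0)]))   # 0.65
```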
*/ - def sliding(windowSize: Int): RDD[Seq[T]] = { + def sliding(windowSize: Int): RDD[Array[T]] = { require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") if (windowSize == 1) { - self.map(Seq(_)) + self.map(Array(_)) } else { new SlidingRDD[T](self, windowSize) } @@ -112,7 +113,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { } } -private[mllib] +@DeveloperApi object RDDFunctions { /** Implicit conversion from an RDD to RDDFunctions. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index dd80782c0f00..35e81fcb3de0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -45,15 +45,16 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T] */ private[mllib] class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) - extends RDD[Seq[T]](parent) { + extends RDD[Array[T]](parent) { require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") - override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = { + override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = { val part = split.asInstanceOf[SlidingRDDPartition[T]] (firstParent[T].iterator(part.prev, context) ++ part.tail) .sliding(windowSize) .withPartial(false) + .map(_.toArray) } override def getPreferredLocations(split: Partition): Seq[String] = diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 27a19f793242..4ef67a40b9f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -42,9 +42,9 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) assert(rdd.partitions.size === data.length) - val sliding = rdd.sliding(3) - val expected = data.flatMap(x => x).sliding(3).toList - assert(sliding.collect().toList === expected) + val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) + val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) + assert(sliding === expected) } test("treeAggregate") { From e5c7869f20139832ad9e636eaeb5e77da7297456 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 4 Nov 2014 18:14:28 -0800 Subject: [PATCH 016/652] [SQL] Add String option for DSL AS Author: Michael Armbrust Closes #3097 from marmbrus/asString and squashes the following commits: 6430520 [Michael Armbrust] Add String option for DSL AS (cherry picked from commit 515abb9afa2d6b58947af6bb079a493b49d315ca) Signed-off-by: Xiangrui Meng --- .../main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3314e1547701..31dc5a58e68e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -110,7 +110,8 @@ package object dsl { def asc = SortOrder(expr, Ascending) def desc = 
SortOrder(expr, Descending) - def as(s: Symbol) = Alias(expr, s.name)() + def as(alias: String) = Alias(expr, alias)() + def as(alias: Symbol) = Alias(expr, alias.name)() } trait ExpressionConversions { From f225b3cc18698b2ee8a94c8ffa0b6aca2fce7cf9 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 4 Nov 2014 21:35:52 -0800 Subject: [PATCH 017/652] [SPARK-3964] [MLlib] [PySpark] add Hypothesis test Python API ``` pyspark.mllib.stat.StatisticschiSqTest(observed, expected=None) :: Experimental :: If `observed` is Vector, conduct Pearson's chi-squared goodness of fit test of the observed data against the expected distribution, or againt the uniform distribution (by default), with each category having an expected frequency of `1 / len(observed)`. (Note: `observed` cannot contain negative values) If `observed` is matrix, conduct Pearson's independence test on the input contingency matrix, which cannot contain negative entries or columns or rows that sum up to 0. If `observed` is an RDD of LabeledPoint, conduct Pearson's independence test for every feature against the label across the input RDD. For each feature, the (feature, label) pairs are converted into a contingency matrix for which the chi-squared statistic is computed. All label and feature values must be categorical. :param observed: it could be a vector containing the observed categorical counts/relative frequencies, or the contingency matrix (containing either counts or relative frequencies), or an RDD of LabeledPoint containing the labeled dataset with categorical features. Real-valued features will be treated as categorical for each distinct value. :param expected: Vector containing the expected categorical counts/relative frequencies. `expected` is rescaled if the `expected` sum differs from the `observed` sum. :return: ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, the method used, and the null hypothesis. ``` Author: Davies Liu Closes #3091 from davies/his and squashes the following commits: 145d16c [Davies Liu] address comments 0ab0764 [Davies Liu] fix float 5097d54 [Davies Liu] add Hypothesis test Python API (cherry picked from commit c8abddc5164d8cf11cdede6ab3d5d1ea08028708) Signed-off-by: Xiangrui Meng --- docs/mllib-statistics.md | 40 +++++ .../mllib/api/python/PythonMLLibAPI.scala | 26 ++++ python/pyspark/mllib/common.py | 7 +- python/pyspark/mllib/linalg.py | 13 +- python/pyspark/mllib/stat.py | 137 +++++++++++++++++- 5 files changed, 219 insertions(+), 4 deletions(-) diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 10a5131c0741..ca8c29218f52 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -380,6 +380,46 @@ for (ChiSqTestResult result : featureTestResults) { {% endhighlight %} +
+[`Statistics`](api/python/index.html#pyspark.mllib.stat.Statistics$) provides methods to +run Pearson's chi-squared tests. The following example demonstrates how to run and interpret +hypothesis tests. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.linalg import Vectors, Matrices +from pyspark.mllib.regresssion import LabeledPoint +from pyspark.mllib.stat import Statistics + +sc = SparkContext() + +vec = Vectors.dense(...) # a vector composed of the frequencies of events + +# compute the goodness of fit. If a second vector to test against is not supplied as a parameter, +# the test runs against a uniform distribution. +goodnessOfFitTestResult = Statistics.chiSqTest(vec) +print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. + +mat = Matrices.dense(...) # a contingency matrix + +# conduct Pearson's independence test on the input contingency matrix +independenceTestResult = Statistics.chiSqTest(mat) +print independenceTestResult # summary of the test including the p-value, degrees of freedom... + +obs = sc.parallelize(...) # LabeledPoint(feature, label) . + +# The contingency table is constructed from an RDD of LabeledPoint and used to conduct +# the independence test. Returns an array containing the ChiSquaredTestResult for every feature +# against the label. +featureTestResults = Statistics.chiSqTest(obs) + +for i, result in enumerate(featureTestResults): + print "Column $d:" % (i + 1) + print result +{% endhighlight %} +
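
For a concrete, self-contained variant of the snippet above (which leaves the input vectors as `...` placeholders), the same three tests can be run from Scala; the input values below mirror the doctests added to `python/pyspark/mllib/stat.py` later in this patch, and a live SparkContext `sc` is assumed:

    import org.apache.spark.mllib.linalg.{Matrices, Vectors}
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.stat.Statistics

    // Goodness of fit against the uniform distribution: statistic 0.4, p-value ~0.8187.
    val goodnessOfFit = Statistics.chiSqTest(Vectors.dense(4.0, 6.0, 5.0))
    println(goodnessOfFit)  // toString gives the full summary of the test

    // Independence test on a 3x4 contingency matrix (values are column-major): statistic ~21.9958.
    val mat = Matrices.dense(3, 4,
      Array(40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0))
    println(Statistics.chiSqTest(mat).statistic)

    // Per-feature independence test against the label: one ChiSqTestResult per feature.
    val labeled = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.5, 10.0)),
      LabeledPoint(0.0, Vectors.dense(1.5, 20.0)),
      LabeledPoint(1.0, Vectors.dense(1.5, 30.0)),
      LabeledPoint(0.0, Vectors.dense(3.5, 30.0)),
      LabeledPoint(0.0, Vectors.dense(3.5, 40.0)),
      LabeledPoint(1.0, Vectors.dense(3.5, 40.0))))
    Statistics.chiSqTest(labeled).zipWithIndex.foreach { case (result, i) =>
      println(s"Column ${i + 1}: p-value = ${result.pValue}")
    }
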
+ ## Random data generation diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 65b98a8ceea5..d832ae34b55e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -43,6 +43,7 @@ import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames +import org.apache.spark.mllib.stat.test.ChiSqTestResult import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -454,6 +455,31 @@ class PythonMLLibAPI extends Serializable { Statistics.corr(x.rdd, y.rdd, getCorrNameOrDefault(method)) } + /** + * Java stub for mllib Statistics.chiSqTest() + */ + def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { + if (expected == null) { + Statistics.chiSqTest(observed) + } else { + Statistics.chiSqTest(observed, expected) + } + } + + /** + * Java stub for mllib Statistics.chiSqTest(observed: Matrix) + */ + def chiSqTest(observed: Matrix): ChiSqTestResult = { + Statistics.chiSqTest(observed) + } + + /** + * Java stub for mllib Statistics.chiSqTest(RDD[LabelPoint]) + */ + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = { + Statistics.chiSqTest(data.rdd) + } + // used by the corr methods to retrieve the name of the correlation method passed in via pyspark private def getCorrNameOrDefault(method: String) = { if (method == null) CorrelationNames.defaultCorrName else method diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index dbe5f698b734..c6149fe391ec 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -98,8 +98,13 @@ def _java2py(sc, r): jrdd = sc._jvm.SerDe.javaToPython(r) return RDD(jrdd, sc) - elif isinstance(r, (JavaArray, JavaList)) or clsName in _picklable_classes: + if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) + elif isinstance(r, (JavaArray, JavaList)): + try: + r = sc._jvm.SerDe.dumps(r) + except Py4JJavaError: + pass # not pickable if isinstance(r, bytearray): r = PickleSerializer().loads(str(r)) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index c0c3dff31e7f..e35202dca0ac 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -33,7 +33,7 @@ IntegerType, ByteType, Row -__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] if sys.version_info[:2] == (2, 7): @@ -578,6 +578,8 @@ class DenseMatrix(Matrix): def __init__(self, numRows, numCols, values): Matrix.__init__(self, numRows, numCols) assert len(values) == numRows * numCols + if not isinstance(values, array.array): + values = array.array('d', values) self.values = values def __reduce__(self): @@ -596,6 +598,15 @@ def toArray(self): return np.reshape(self.values, (self.numRows, self.numCols), order='F') +class Matrices(object): + @staticmethod + def dense(numRows, numCols, values): + """ + Create a DenseMatrix + """ + return DenseMatrix(numRows, numCols, values) + + def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) diff --git 
a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 15f0652f833d..0700f8a8e5a8 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -19,11 +19,12 @@ Python package for statistical functions in MLlib. """ +from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import _convert_to_vector +from pyspark.mllib.linalg import Matrix, _convert_to_vector -__all__ = ['MultivariateStatisticalSummary', 'Statistics'] +__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] class MultivariateStatisticalSummary(JavaModelWrapper): @@ -51,6 +52,54 @@ def min(self): return self.call("min").toArray() +class ChiSqTestResult(JavaModelWrapper): + """ + :: Experimental :: + + Object containing the test results for the chi-squared hypothesis test. + """ + @property + def method(self): + """ + Name of the test method + """ + return self._java_model.method() + + @property + def pValue(self): + """ + The probability of obtaining a test statistic result at least as + extreme as the one that was actually observed, assuming that the + null hypothesis is true. + """ + return self._java_model.pValue() + + @property + def degreesOfFreedom(self): + """ + Returns the degree(s) of freedom of the hypothesis test. + Return type should be Number(e.g. Int, Double) or tuples of Numbers. + """ + return self._java_model.degreesOfFreedom() + + @property + def statistic(self): + """ + Test statistic. + """ + return self._java_model.statistic() + + @property + def nullHypothesis(self): + """ + Null hypothesis of the test. + """ + return self._java_model.nullHypothesis() + + def __str__(self): + return self._java_model.toString() + + class Statistics(object): @staticmethod @@ -135,6 +184,90 @@ def corr(x, y=None, method=None): else: return callMLlibFunc("corr", x.map(float), y.map(float), method) + @staticmethod + def chiSqTest(observed, expected=None): + """ + :: Experimental :: + + If `observed` is Vector, conduct Pearson's chi-squared goodness + of fit test of the observed data against the expected distribution, + or againt the uniform distribution (by default), with each category + having an expected frequency of `1 / len(observed)`. + (Note: `observed` cannot contain negative values) + + If `observed` is matrix, conduct Pearson's independence test on the + input contingency matrix, which cannot contain negative entries or + columns or rows that sum up to 0. + + If `observed` is an RDD of LabeledPoint, conduct Pearson's independence + test for every feature against the label across the input RDD. + For each feature, the (feature, label) pairs are converted into a + contingency matrix for which the chi-squared statistic is computed. + All label and feature values must be categorical. + + :param observed: it could be a vector containing the observed categorical + counts/relative frequencies, or the contingency matrix + (containing either counts or relative frequencies), + or an RDD of LabeledPoint containing the labeled dataset + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + :param expected: Vector containing the expected categorical counts/relative + frequencies. `expected` is rescaled if the `expected` sum + differs from the `observed` sum. + :return: ChiSquaredTest object containing the test statistic, degrees + of freedom, p-value, the method used, and the null hypothesis. 
+ + >>> from pyspark.mllib.linalg import Vectors, Matrices + >>> observed = Vectors.dense([4, 6, 5]) + >>> pearson = Statistics.chiSqTest(observed) + >>> print pearson.statistic + 0.4 + >>> pearson.degreesOfFreedom + 2 + >>> print round(pearson.pValue, 4) + 0.8187 + >>> pearson.method + u'pearson' + >>> pearson.nullHypothesis + u'observed follows the same distribution as expected.' + + >>> observed = Vectors.dense([21, 38, 43, 80]) + >>> expected = Vectors.dense([3, 5, 7, 20]) + >>> pearson = Statistics.chiSqTest(observed, expected) + >>> print round(pearson.pValue, 4) + 0.0027 + + >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] + >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + >>> print round(chi.statistic, 4) + 21.9958 + + >>> from pyspark.mllib.regression import LabeledPoint + >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), + ... LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), + ... LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), + ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),] + >>> rdd = sc.parallelize(data, 4) + >>> chi = Statistics.chiSqTest(rdd) + >>> print chi[0].statistic + 0.75 + >>> print chi[1].statistic + 1.5 + """ + if isinstance(observed, RDD): + jmodels = callMLlibFunc("chiSqTest", observed) + return [ChiSqTestResult(m) for m in jmodels] + + if isinstance(observed, Matrix): + jmodel = callMLlibFunc("chiSqTest", observed) + else: + if expected and len(expected) != len(observed): + raise ValueError("`expected` should have same length with `observed`") + jmodel = callMLlibFunc("chiSqTest", _convert_to_vector(observed), expected) + return ChiSqTestResult(jmodel) + def _test(): import doctest From 46654b0661257f432932c6efc09c4c0983521834 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 5 Nov 2014 01:21:53 -0800 Subject: [PATCH 018/652] [SPARK-4029][Streaming] Update streaming driver to reliably save and recover received block metadata on driver failures As part of the initiative of preventing data loss on driver failure, this JIRA tracks the sub task of modifying the streaming driver to reliably save received block metadata, and recover them on driver restart. This was solved by introducing a `ReceivedBlockTracker` that takes all the responsibility of managing the metadata of received blocks (i.e. `ReceivedBlockInfo`, and any actions on them (e.g, allocating blocks to batches, etc.). All actions to block info get written out to a write ahead log (using `WriteAheadLogManager`). On recovery, all the actions are replaying to recreate the pre-failure state of the `ReceivedBlockTracker`, which include the batch-to-block allocations and the unallocated blocks. Furthermore, the `ReceiverInputDStream` was modified to create `WriteAheadLogBackedBlockRDD`s when file segment info is present in the `ReceivedBlockInfo`. After recovery of all the block info (through recovery `ReceivedBlockTracker`), the `WriteAheadLogBackedBlockRDD`s gets recreated with the recovered info, and jobs submitted. The data of the blocks gets pulled from the write ahead logs, thanks to the segment info present in the `ReceivedBlockInfo`. This is still a WIP. Things that are missing here are. - *End-to-end integration tests:* Unit tests that tests the driver recovery, by killing and restarting the streaming context, and verifying all the input data gets processed. This has been implemented but not included in this PR yet. 
A sneak peek of that DriverFailureSuite can be found in this PR (on my personal repo): https://github.com/tdas/spark/pull/25 I can either include it in this PR, or submit that as a separate PR after this gets in. - *WAL cleanup:* Cleaning up the received data write ahead log, by calling `ReceivedBlockHandler.cleanupOldBlocks`. This is being worked on. Author: Tathagata Das Closes #3026 from tdas/driver-ha-rbt and squashes the following commits: a8009ed [Tathagata Das] Added comment 1d704bb [Tathagata Das] Enabled storing recovered WAL-backed blocks to BM 2ee2484 [Tathagata Das] More minor changes based on PR 47fc1e3 [Tathagata Das] Addressed PR comments. 9a7e3e4 [Tathagata Das] Refactored ReceivedBlockTracker API a bit to make things a little cleaner for users of the tracker. af63655 [Tathagata Das] Minor changes. fce2b21 [Tathagata Das] Removed commented lines 59496d3 [Tathagata Das] Changed class names, made allocation more explicit and added cleanup 19aec7d [Tathagata Das] Fixed casting bug. f66d277 [Tathagata Das] Fix line lengths. cda62ee [Tathagata Das] Added license 25611d6 [Tathagata Das] Minor changes before submitting PR 7ae0a7fb [Tathagata Das] Transferred changes from driver-ha-working branch (cherry picked from commit 5f13759d3642ea5b58c12a756e7125ac19aff10e) Signed-off-by: Tathagata Das --- .../dstream/ReceiverInputDStream.scala | 69 +++-- .../rdd/WriteAheadLogBackedBlockRDD.scala | 3 +- .../streaming/scheduler/JobGenerator.scala | 21 +- .../scheduler/ReceivedBlockTracker.scala | 230 +++++++++++++++++ .../streaming/scheduler/ReceiverTracker.scala | 98 ++++--- .../streaming/BasicOperationsSuite.scala | 19 +- .../streaming/ReceivedBlockTrackerSuite.scala | 242 ++++++++++++++++++ .../WriteAheadLogBackedBlockRDDSuite.scala | 4 +- 8 files changed, 597 insertions(+), 89 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index bb47d373de63..3e67161363e5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -17,15 +17,14 @@ package org.apache.spark.streaming.dstream -import scala.collection.mutable.HashMap import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.storage.BlockId +import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming._ -import org.apache.spark.streaming.receiver.{WriteAheadLogBasedStoreResult, BlockManagerBasedStoreResult, Receiver} +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD +import org.apache.spark.streaming.receiver.{Receiver, WriteAheadLogBasedStoreResult} import org.apache.spark.streaming.scheduler.ReceivedBlockInfo -import org.apache.spark.SparkException /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -40,9 +39,6 @@ import org.apache.spark.SparkException abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { - /** Keeps all received blocks information */ - private lazy val receivedBlockInfo = new HashMap[Time, Array[ReceivedBlockInfo]] - /** This is an unique 
identifier for the network input stream. */ val id = ssc.getNewReceiverStreamId() @@ -58,24 +54,45 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont def stop() {} - /** Ask ReceiverInputTracker for received data blocks and generates RDDs with them. */ + /** + * Generates RDDs with blocks received by the receiver of this stream. */ override def compute(validTime: Time): Option[RDD[T]] = { - // If this is called for any time before the start time of the context, - // then this returns an empty RDD. This may happen when recovering from a - // master failure - if (validTime >= graph.startTime) { - val blockInfo = ssc.scheduler.receiverTracker.getReceivedBlockInfo(id) - receivedBlockInfo(validTime) = blockInfo - val blockIds = blockInfo.map { _.blockStoreResult.blockId.asInstanceOf[BlockId] } - Some(new BlockRDD[T](ssc.sc, blockIds)) - } else { - Some(new BlockRDD[T](ssc.sc, Array.empty)) - } - } + val blockRDD = { - /** Get information on received blocks. */ - private[streaming] def getReceivedBlockInfo(time: Time) = { - receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo]) + if (validTime < graph.startTime) { + // If this is called for any time before the start time of the context, + // then this returns an empty RDD. This may happen when recovering from a + // driver failure without any write ahead log to recover pre-failure data. + new BlockRDD[T](ssc.sc, Array.empty) + } else { + // Otherwise, ask the tracker for all the blocks that have been allocated to this stream + // for this batch + val blockInfos = + ssc.scheduler.receiverTracker.getBlocksOfBatch(validTime).get(id).getOrElse(Seq.empty) + val blockStoreResults = blockInfos.map { _.blockStoreResult } + val blockIds = blockStoreResults.map { _.blockId.asInstanceOf[BlockId] }.toArray + + // Check whether all the results are of the same type + val resultTypes = blockStoreResults.map { _.getClass }.distinct + if (resultTypes.size > 1) { + logWarning("Multiple result types in block information, WAL information will be ignored.") + } + + // If all the results are of type WriteAheadLogBasedStoreResult, then create + // WriteAheadLogBackedBlockRDD else create simple BlockRDD. + if (resultTypes.size == 1 && resultTypes.head == classOf[WriteAheadLogBasedStoreResult]) { + val logSegments = blockStoreResults.map { + _.asInstanceOf[WriteAheadLogBasedStoreResult].segment + }.toArray + // Since storeInBlockManager = false, the storage level does not matter. 
+ new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext, + blockIds, logSegments, storeInBlockManager = true, StorageLevel.MEMORY_ONLY_SER) + } else { + new BlockRDD[T](ssc.sc, blockIds) + } + } + } + Some(blockRDD) } /** @@ -86,10 +103,6 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont */ private[streaming] override def clearMetadata(time: Time) { super.clearMetadata(time) - val oldReceivedBlocks = receivedBlockInfo.filter(_._1 <= (time - rememberDuration)) - receivedBlockInfo --= oldReceivedBlocks.keys - logDebug("Cleared " + oldReceivedBlocks.size + " RDDs that were older than " + - (time - rememberDuration) + ": " + oldReceivedBlocks.keys.mkString(", ")) + ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration) } } - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 23295bf65871..dd1e96334952 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -48,7 +48,6 @@ class WriteAheadLogBackedBlockRDDPartition( * If it does not find them, it looks up the corresponding file segment. * * @param sc SparkContext - * @param hadoopConfig Hadoop configuration * @param blockIds Ids of the blocks that contains this RDD's data * @param segments Segments in write ahead logs that contain this RDD's data * @param storeInBlockManager Whether to store in the block manager after reading from the segment @@ -58,7 +57,6 @@ class WriteAheadLogBackedBlockRDDPartition( private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( @transient sc: SparkContext, - @transient hadoopConfig: Configuration, @transient blockIds: Array[BlockId], @transient segments: Array[WriteAheadLogFileSegment], storeInBlockManager: Boolean, @@ -71,6 +69,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( s"the same as number of segments (${segments.length}})!") // Hadoop configuration is not serializable, so broadcast it as a serializable. + @transient private val hadoopConfig = sc.hadoopConfiguration private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig) override def getPartitions: Array[Partition] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 7d73ada12d10..39b66e113076 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -112,7 +112,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Wait until all the received blocks in the network input tracker has // been consumed by network input DStreams, and jobs have been generated with them logInfo("Waiting for all received blocks to be consumed for job generation") - while(!hasTimedOut && jobScheduler.receiverTracker.hasMoreReceivedBlockIds) { + while(!hasTimedOut && jobScheduler.receiverTracker.hasUnallocatedBlocks) { Thread.sleep(pollTime) } logInfo("Waited for all received blocks to be consumed for job generation") @@ -217,14 +217,18 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the given `time`. 
*/ private def generateJobs(time: Time) { - Try(graph.generateJobs(time)) match { + // Set the SparkEnv in this thread, so that job generation code can access the environment + // Example: BlockRDDs are created in this thread, and it needs to access BlockManager + // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed. + SparkEnv.set(ssc.env) + Try { + jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch + graph.generateJobs(time) // generate jobs using allocated block + } match { case Success(jobs) => - val receivedBlockInfo = graph.getReceiverInputStreams.map { stream => - val streamId = stream.id - val receivedBlockInfo = stream.getReceivedBlockInfo(time) - (streamId, receivedBlockInfo) - }.toMap - jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo)) + val receivedBlockInfos = + jobScheduler.receiverTracker.getBlocksOfBatch(time).mapValues { _.toArray } + jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } @@ -234,6 +238,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) + jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration) // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala new file mode 100644 index 000000000000..5f5e1909908d --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.language.implicitConversions + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.util.{Clock, WriteAheadLogManager} +import org.apache.spark.util.Utils + +/** Trait representing any event in the ReceivedBlockTracker that updates its state. 
*/ +private[streaming] sealed trait ReceivedBlockTrackerLogEvent + +private[streaming] case class BlockAdditionEvent(receivedBlockInfo: ReceivedBlockInfo) + extends ReceivedBlockTrackerLogEvent +private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: AllocatedBlocks) + extends ReceivedBlockTrackerLogEvent +private[streaming] case class BatchCleanupEvent(times: Seq[Time]) + extends ReceivedBlockTrackerLogEvent + + +/** Class representing the blocks of all the streams allocated to a batch */ +private[streaming] +case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { + def getBlocksOfStream(streamId: Int): Seq[ReceivedBlockInfo] = { + streamIdToAllocatedBlocks.get(streamId).getOrElse(Seq.empty) + } +} + +/** + * Class that keep track of all the received blocks, and allocate them to batches + * when required. All actions taken by this class can be saved to a write ahead log + * (if a checkpoint directory has been provided), so that the state of the tracker + * (received blocks and block-to-batch allocations) can be recovered after driver failure. + * + * Note that when any instance of this class is created with a checkpoint directory, + * it will try reading events from logs in the directory. + */ +private[streaming] class ReceivedBlockTracker( + conf: SparkConf, + hadoopConf: Configuration, + streamIds: Seq[Int], + clock: Clock, + checkpointDirOption: Option[String]) + extends Logging { + + private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo] + + private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] + private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] + + private val logManagerRollingIntervalSecs = conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60) + private val logManagerOption = checkpointDirOption.map { checkpointDir => + new WriteAheadLogManager( + ReceivedBlockTracker.checkpointDirToLogDir(checkpointDir), + hadoopConf, + rollingIntervalSecs = logManagerRollingIntervalSecs, + callerName = "ReceivedBlockHandlerMaster", + clock = clock + ) + } + + private var lastAllocatedBatchTime: Time = null + + // Recover block information from write ahead logs + recoverFromWriteAheadLogs() + + /** Add received block. This event will get written to the write ahead log (if enabled). */ + def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { + try { + writeToLog(BlockAdditionEvent(receivedBlockInfo)) + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + logDebug(s"Stream ${receivedBlockInfo.streamId} received " + + s"block ${receivedBlockInfo.blockStoreResult.blockId}") + true + } catch { + case e: Exception => + logError(s"Error adding block $receivedBlockInfo", e) + false + } + } + + /** + * Allocate all unallocated blocks to the given batch. + * This event will get written to the write ahead log (if enabled). 
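
As an illustrative aside (not part of this file), the intended call sequence of the tracker, mirroring the ReceivedBlockTrackerSuite added later in this patch, looks roughly like the sketch below; these classes are private[streaming], so the code only compiles inside that package, and the stream id, batch times and checkpoint directory are placeholder values:

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.SparkConf
    import org.apache.spark.storage.StreamBlockId
    import org.apache.spark.streaming.Time
    import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
    import org.apache.spark.streaming.util.SystemClock

    val streamId = 1
    val tracker = new ReceivedBlockTracker(
      new SparkConf(), new Configuration(), Seq(streamId), new SystemClock, Some("/tmp/checkpointDir"))

    // A receiver reports a stored block; the addition is persisted as a BlockAdditionEvent.
    val info = ReceivedBlockInfo(streamId, 0,
      BlockManagerBasedStoreResult(StreamBlockId(streamId, 0)))
    tracker.addBlock(info)

    // The JobGenerator then allocates all pending blocks to a batch (BatchAllocationEvent) ...
    tracker.allocateBlocksToBatch(Time(1000))
    val blocks = tracker.getBlocksOfBatchAndStream(Time(1000), streamId)

    // ... and, once older batches are fully processed, cleans them up (BatchCleanupEvent).
    tracker.cleanupOldBatches(Time(2000))
    tracker.stop()
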
+ */ + def allocateBlocksToBatch(batchTime: Time): Unit = synchronized { + if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) { + val streamIdToBlocks = streamIds.map { streamId => + (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) + }.toMap + val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) + writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) + timeToAllocatedBlocks(batchTime) = allocatedBlocks + lastAllocatedBatchTime = batchTime + allocatedBlocks + } else { + throw new SparkException(s"Unexpected allocation of blocks, " + + s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime ") + } + } + + /** Get the blocks allocated to the given batch. */ + def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = synchronized { + timeToAllocatedBlocks.get(batchTime).map { _.streamIdToAllocatedBlocks }.getOrElse(Map.empty) + } + + /** Get the blocks allocated to the given batch and stream. */ + def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + synchronized { + timeToAllocatedBlocks.get(batchTime).map { + _.getBlocksOfStream(streamId) + }.getOrElse(Seq.empty) + } + } + + /** Check if any blocks are left to be allocated to batches. */ + def hasUnallocatedReceivedBlocks: Boolean = synchronized { + !streamIdToUnallocatedBlockQueues.values.forall(_.isEmpty) + } + + /** + * Get blocks that have been added but not yet allocated to any batch. This method + * is primarily used for testing. + */ + def getUnallocatedBlocks(streamId: Int): Seq[ReceivedBlockInfo] = synchronized { + getReceivedBlockQueue(streamId).toSeq + } + + /** Clean up block information of old batches. */ + def cleanupOldBatches(cleanupThreshTime: Time): Unit = synchronized { + assert(cleanupThreshTime.milliseconds < clock.currentTime()) + val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq + logInfo("Deleting batches " + timesToCleanup) + writeToLog(BatchCleanupEvent(timesToCleanup)) + timeToAllocatedBlocks --= timesToCleanup + logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds)) + log + } + + /** Stop the block tracker. */ + def stop() { + logManagerOption.foreach { _.stop() } + } + + /** + * Recover all the tracker actions from the write ahead logs to recover the state (unallocated + * and allocated block info) prior to failure. + */ + private def recoverFromWriteAheadLogs(): Unit = synchronized { + // Insert the recovered block information + def insertAddedBlock(receivedBlockInfo: ReceivedBlockInfo) { + logTrace(s"Recovery: Inserting added block $receivedBlockInfo") + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + } + + // Insert the recovered block-to-batch allocations and clear the queue of received blocks + // (when the blocks were originally allocated to the batch, the queue must have been cleared). 
+ def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) { + logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + + s"${allocatedBlocks.streamIdToAllocatedBlocks}") + streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } + lastAllocatedBatchTime = batchTime + timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + } + + // Cleanup the batch allocations + def cleanupBatches(batchTimes: Seq[Time]) { + logTrace(s"Recovery: Cleaning up batches $batchTimes") + timeToAllocatedBlocks --= batchTimes + } + + logManagerOption.foreach { logManager => + logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") + logManager.readFromLog().foreach { byteBuffer => + logTrace("Recovering record " + byteBuffer) + Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { + case BlockAdditionEvent(receivedBlockInfo) => + insertAddedBlock(receivedBlockInfo) + case BatchAllocationEvent(time, allocatedBlocks) => + insertAllocatedBatch(time, allocatedBlocks) + case BatchCleanupEvent(batchTimes) => + cleanupBatches(batchTimes) + } + } + } + } + + /** Write an update to the tracker to the write ahead log */ + private def writeToLog(record: ReceivedBlockTrackerLogEvent) { + logDebug(s"Writing to log $record") + logManagerOption.foreach { logManager => + logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record))) + } + } + + /** Get the queue of received blocks belonging to a particular stream */ + private def getReceivedBlockQueue(streamId: Int): ReceivedBlockQueue = { + streamIdToUnallocatedBlockQueues.getOrElseUpdate(streamId, new ReceivedBlockQueue) + } +} + +private[streaming] object ReceivedBlockTracker { + def checkpointDirToLogDir(checkpointDir: String): String = { + new Path(checkpointDir, "receivedBlockMetadata").toString + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index d696563bcee8..1c3984d968d2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,15 +17,16 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue} + +import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials import akka.actor._ -import org.apache.spark.{SerializableWritable, Logging, SparkEnv, SparkException} + +import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} -import org.apache.spark.util.AkkaUtils /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -48,23 +49,28 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err * This class manages the execution of the receivers of NetworkInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() * has been called because it needs the final set of input streams at the time of instantiation. + * + * @param skipReceiverLaunch Do not launch the receiver. This is useful for testing. 
*/ private[streaming] -class ReceiverTracker(ssc: StreamingContext) extends Logging { +class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false) extends Logging { - val receiverInputStreams = ssc.graph.getReceiverInputStreams() - val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*) - val receiverExecutor = new ReceiverLauncher() - val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] - val receivedBlockInfo = new HashMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - with SynchronizedMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - val timeout = AkkaUtils.askTimeout(ssc.conf) - val listenerBus = ssc.scheduler.listenerBus + private val receiverInputStreams = ssc.graph.getReceiverInputStreams() + private val receiverInputStreamIds = receiverInputStreams.map { _.id } + private val receiverExecutor = new ReceiverLauncher() + private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] + private val receivedBlockTracker = new ReceivedBlockTracker( + ssc.sparkContext.conf, + ssc.sparkContext.hadoopConfiguration, + receiverInputStreamIds, + ssc.scheduler.clock, + Option(ssc.checkpointDir) + ) + private val listenerBus = ssc.scheduler.listenerBus // actor is created when generator starts. // This not being null means the tracker has been started and not stopped - var actor: ActorRef = null - var currentTime: Time = null + private var actor: ActorRef = null /** Start the actor and receiver execution thread. */ def start() = synchronized { @@ -75,7 +81,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { if (!receiverInputStreams.isEmpty) { actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor), "ReceiverTracker") - receiverExecutor.start() + if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") } } @@ -84,45 +90,59 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { def stop() = synchronized { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers - receiverExecutor.stop() + if (!skipReceiverLaunch) receiverExecutor.stop() // Finally, stop the actor ssc.env.actorSystem.stop(actor) actor = null + receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") } } - /** Return all the blocks received from a receiver. */ - def getReceivedBlockInfo(streamId: Int): Array[ReceivedBlockInfo] = { - val receivedBlockInfo = getReceivedBlockInfoQueue(streamId).dequeueAll(x => true) - logInfo("Stream " + streamId + " received " + receivedBlockInfo.size + " blocks") - receivedBlockInfo.toArray + /** Allocate all unallocated blocks to the given batch. */ + def allocateBlocksToBatch(batchTime: Time): Unit = { + if (receiverInputStreams.nonEmpty) { + receivedBlockTracker.allocateBlocksToBatch(batchTime) + } + } + + /** Get the blocks for the given batch and all input streams. */ + def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = { + receivedBlockTracker.getBlocksOfBatch(batchTime) } - private def getReceivedBlockInfoQueue(streamId: Int) = { - receivedBlockInfo.getOrElseUpdate(streamId, new SynchronizedQueue[ReceivedBlockInfo]) + /** Get the blocks allocated to the given batch and stream. 
*/ + def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + synchronized { + receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) + } + } + + /** Clean up metadata older than the given threshold time */ + def cleanupOldMetadata(cleanupThreshTime: Time) { + receivedBlockTracker.cleanupOldBatches(cleanupThreshTime) } /** Register a receiver */ - def registerReceiver( + private def registerReceiver( streamId: Int, typ: String, host: String, receiverActor: ActorRef, sender: ActorRef ) { - if (!receiverInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) + if (!receiverInputStreamIds.contains(streamId)) { + throw new SparkException("Register received for unexpected id " + streamId) } receiverInfo(streamId) = ReceiverInfo( streamId, s"${typ}-${streamId}", receiverActor, true, host) - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address) } /** Deregister a receiver */ - def deregisterReceiver(streamId: Int, message: String, error: String) { + private def deregisterReceiver(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error) @@ -131,7 +151,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -141,14 +161,12 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Add new blocks for the given stream */ - def addBlocks(receivedBlockInfo: ReceivedBlockInfo) { - getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo - logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " + - receivedBlockInfo.blockStoreResult.blockId) + private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { + receivedBlockTracker.addBlock(receivedBlockInfo) } /** Report error sent by a receiver */ - def reportError(streamId: Int, message: String, error: String) { + private def reportError(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(lastErrorMessage = message, lastError = error) @@ -157,7 +175,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -167,8 +185,8 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Check if any blocks are left to be processed */ - def hasMoreReceivedBlockIds: Boolean = { - !receivedBlockInfo.values.forall(_.isEmpty) + def hasUnallocatedBlocks: Boolean 
= { + receivedBlockTracker.hasUnallocatedReceivedBlocks } /** Actor to receive messages from the receivers. */ @@ -178,8 +196,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { registerReceiver(streamId, typ, host, receiverActor, sender) sender ! true case AddBlock(receivedBlockInfo) => - addBlocks(receivedBlockInfo) - sender ! true + sender ! addBlock(receivedBlockInfo) case ReportError(streamId, message, error) => reportError(streamId, message, error) case DeregisterReceiver(streamId, message, error) => @@ -194,6 +211,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { @transient val thread = new Thread() { override def run() { try { + SparkEnv.set(env) startReceivers() } catch { case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") @@ -267,7 +285,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") - ssc.sparkContext.runJob(tempRDD, startReceiver) + ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) logInfo("All of the receivers have been terminated") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 6c8bb5014536..dbab685dc351 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -17,18 +17,19 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.StreamingContext._ - -import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.SparkContext._ +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.language.existentials +import scala.reflect.ClassTag import util.ManualClock -import org.apache.spark.{SparkException, SparkConf} -import org.apache.spark.streaming.dstream.{WindowedDStream, DStream} -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import scala.reflect.ClassTag + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel -import scala.collection.mutable +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} class BasicOperationsSuite extends TestSuiteBase { test("map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala new file mode 100644 index 000000000000..fd9c97f551c6 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.File + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.{implicitConversions, postfixOps} +import scala.util.Random + +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock, WriteAheadLogReader} +import org.apache.spark.streaming.util.WriteAheadLogSuite._ +import org.apache.spark.util.Utils + +class ReceivedBlockTrackerSuite + extends FunSuite with BeforeAndAfter with Matchers with Logging { + + val conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite") + conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1") + + val hadoopConf = new Configuration() + val akkaTimeout = 10 seconds + val streamId = 1 + + var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]() + var checkpointDirectory: File = null + + before { + checkpointDirectory = Files.createTempDir() + } + + after { + allReceivedBlockTrackers.foreach { _.stop() } + if (checkpointDirectory != null && checkpointDirectory.exists()) { + FileUtils.deleteDirectory(checkpointDirectory) + checkpointDirectory = null + } + } + + test("block addition, and block to batch allocation") { + val receivedBlockTracker = createTracker(enableCheckpoint = false) + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + + val blockInfos = generateBlockInfos() + blockInfos.map(receivedBlockTracker.addBlock) + + // Verify added blocks are unallocated blocks + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos + + // Allocate the blocks to a batch and verify that all of them have been allocated + receivedBlockTracker.allocateBlocksToBatch(1) + receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldBe empty + + // Allocate no blocks to another batch + receivedBlockTracker.allocateBlocksToBatch(2) + receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty + + // Verify that batch 2 cannot be allocated again + intercept[SparkException] { + receivedBlockTracker.allocateBlocksToBatch(2) + } + + // Verify that older batches cannot be allocated again + intercept[SparkException] { + receivedBlockTracker.allocateBlocksToBatch(1) + } + } + + test("block addition, block to batch allocation and cleanup with write ahead log") { + val manualClock = new ManualClock + conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", -1) should be (1) + + // Set the time increment level to twice the rotation interval so that every 
increment creates + // a new log file + val timeIncrementMillis = 2000L + def incrementTime() { + manualClock.addToTime(timeIncrementMillis) + } + + // Generate and add blocks to the given tracker + def addBlockInfos(tracker: ReceivedBlockTracker): Seq[ReceivedBlockInfo] = { + val blockInfos = generateBlockInfos() + blockInfos.map(tracker.addBlock) + blockInfos + } + + // Print the data present in the log ahead files in the log directory + def printLogFiles(message: String) { + val fileContents = getWriteAheadLogFiles().map { file => + (s"\n>>>>> $file: <<<<<\n${getWrittenLogData(file).mkString("\n")}") + }.mkString("\n") + logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n") + } + + // Start tracker and add blocks + val tracker1 = createTracker(enableCheckpoint = true, clock = manualClock) + val blockInfos1 = addBlockInfos(tracker1) + tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Verify whether write ahead log has correct contents + val expectedWrittenData1 = blockInfos1.map(BlockAdditionEvent) + getWrittenLogData() shouldEqual expectedWrittenData1 + getWriteAheadLogFiles() should have size 1 + + // Restart tracker and verify recovered list of unallocated blocks + incrementTime() + val tracker2 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Allocate blocks to batch and verify whether the unallocated blocks got allocated + val batchTime1 = manualClock.currentTime + tracker2.allocateBlocksToBatch(batchTime1) + tracker2.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 + + // Add more blocks and allocate to another batch + incrementTime() + val batchTime2 = manualClock.currentTime + val blockInfos2 = addBlockInfos(tracker2) + tracker2.allocateBlocksToBatch(batchTime2) + tracker2.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + + // Verify whether log has correct contents + val expectedWrittenData2 = expectedWrittenData1 ++ + Seq(createBatchAllocation(batchTime1, blockInfos1)) ++ + blockInfos2.map(BlockAdditionEvent) ++ + Seq(createBatchAllocation(batchTime2, blockInfos2)) + getWrittenLogData() shouldEqual expectedWrittenData2 + + // Restart tracker and verify recovered state + incrementTime() + val tracker3 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 + tracker3.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + tracker3.getUnallocatedBlocks(streamId) shouldBe empty + + // Cleanup first batch but not second batch + val oldestLogFile = getWriteAheadLogFiles().head + incrementTime() + tracker3.cleanupOldBatches(batchTime2) + + // Verify that the batch allocations have been cleaned, and the act has been written to log + tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual Seq.empty + getWrittenLogData(getWriteAheadLogFiles().last) should contain(createBatchCleanup(batchTime1)) + + // Verify that at least one log file gets deleted + eventually(timeout(10 seconds), interval(10 millisecond)) { + getWriteAheadLogFiles() should not contain oldestLogFile + } + printLogFiles("After cleanup") + + // Restart tracker and verify recovered state, specifically whether info about the first + // batch has been removed, but not the second batch + incrementTime() + val tracker4 = createTracker(enableCheckpoint = true, clock = manualClock) + 
tracker4.getUnallocatedBlocks(streamId) shouldBe empty + tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned + tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + } + + /** + * Create tracker object with the optional provided clock. Use fake clock if you + * want to control time by manually incrementing it to test log cleanup. + */ + def createTracker(enableCheckpoint: Boolean, clock: Clock = new SystemClock): ReceivedBlockTracker = { + val cpDirOption = if (enableCheckpoint) Some(checkpointDirectory.toString) else None + val tracker = new ReceivedBlockTracker(conf, hadoopConf, Seq(streamId), clock, cpDirOption) + allReceivedBlockTrackers += tracker + tracker + } + + /** Generate blocks infos using random ids */ + def generateBlockInfos(): Seq[ReceivedBlockInfo] = { + List.fill(5)(ReceivedBlockInfo(streamId, 0, + BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) + } + + /** Get all the data written in the given write ahead log file. */ + def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerLogEvent] = { + getWrittenLogData(Seq(logFile)) + } + + /** + * Get all the data written in the given write ahead log files. By default, it will read all + * files in the test log directory. + */ + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + logFiles.flatMap { + file => new WriteAheadLogReader(file, hadoopConf).toSeq + }.map { byteBuffer => + Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) + }.toList + } + + /** Get all the write ahead log files in the test directory */ + def getWriteAheadLogFiles(): Seq[String] = { + import ReceivedBlockTracker._ + val logDir = checkpointDirToLogDir(checkpointDirectory.toString) + getLogFilesInDirectory(logDir).map { _.toString } + } + + /** Create batch allocation object from the given info */ + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) + } + + /** Create batch cleanup object from the given info */ + def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanupEvent = { + BatchCleanupEvent((Seq(time) ++ moreTimes).map(Time.apply)) + } + + implicit def millisToTime(milliseconds: Long): Time = Time(milliseconds) + + implicit def timeToMillis(time: Time): Long = time.milliseconds +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 10160244bcc9..d2b983c4b4d1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -117,12 +117,12 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll { ) // Create the RDD and verify whether the returned data is correct - val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray, + val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY) assert(rdd.collect() === data.flatten) if (testStoreInBM) { - val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray, + val rdd2 = new 
WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY) assert(rdd2.collect() === data.flatten) assert( From 9cba88c7f9fdf151217716e4cc5fa75995736922 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 5 Nov 2014 10:33:13 -0800 Subject: [PATCH 019/652] [SPARK-4197] [mllib] GradientBoosting API cleanup and examples in Scala, Java MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Summary * Made it easier to construct default Strategy and BoostingStrategy and to set parameters using simple types. * Added Scala and Java examples for GradientBoostedTrees * small cleanups and fixes ### Details GradientBoosting bug fixes (“bug” = bad default options) * Force boostingStrategy.weakLearnerParams.algo = Regression * Force boostingStrategy.weakLearnerParams.impurity = impurity.Variance * Only persist data if not yet persisted (since it causes an error if persisted twice) BoostingStrategy * numEstimators: renamed to numIterations * removed subsamplingRate (duplicated by Strategy) * removed categoricalFeaturesInfo since it belongs with the weak learner params (since boosting can be oblivious to feature type) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added assertValid() method * Updated defaultParams() method and eliminated defaultWeakLearnerParams() since that belongs in Strategy Strategy (for DecisionTree) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added setCategoricalFeaturesInfo method taking Java Map. * Cleaned up assertValid * Changed val’s to def’s since parameters can now be changed. CC: manishamde mengxr codedeft Author: Joseph K. Bradley Closes #3094 from jkbradley/gbt-api and squashes the following commits: 7a27e22 [Joseph K. Bradley] scalastyle fix 52013d5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into gbt-api e9b8410 [Joseph K. Bradley] Summary of changes (cherry picked from commit 5b3b6f6f5f029164d7749366506e142b104c1d43) Signed-off-by: Xiangrui Meng --- .../mllib/JavaGradientBoostedTrees.java | 126 +++++++++++++ .../examples/mllib/DecisionTreeRunner.scala | 64 +++++-- .../examples/mllib/GradientBoostedTrees.scala | 146 +++++++++++++++ .../spark/mllib/tree/GradientBoosting.scala | 169 ++++++------------ .../tree/configuration/BoostingStrategy.scala | 78 ++++---- .../mllib/tree/configuration/Strategy.scala | 51 ++++-- .../mllib/tree/GradientBoostingSuite.scala | 34 ++-- 7 files changed, 462 insertions(+), 206 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java new file mode 100644 index 000000000000..1af2067b2b92 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoosting; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel; +import org.apache.spark.mllib.util.MLUtils; + +/** + * Classification and regression using gradient-boosted decision trees. + */ +public final class JavaGradientBoostedTrees { + + private static void usage() { + System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" + + " <Classification/Regression>"); + System.exit(-1); + } + + public static void main(String[] args) { + String datapath = "data/mllib/sample_libsvm_data.txt"; + String algo = "Classification"; + if (args.length >= 1) { + datapath = args[0]; + } + if (args.length >= 2) { + algo = args[1]; + } + if (args.length > 2) { + usage(); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + + // Set parameters. + // Note: All features are treated as continuous. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); + boostingStrategy.setNumIterations(10); + boostingStrategy.weakLearnerParams().setMaxDepth(5); + + if (algo.equals("Classification")) { + // Compute the number of classes from the data. + Integer numClasses = data.map(new Function<LabeledPoint, Double>() { + @Override public Double call(LabeledPoint p) { + return p.label(); + } + }).countByValue().size(); + boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression + + // Train a GradientBoosting model for classification. + final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD<Double, Double> predictionAndLabel = + data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { + @Override public Tuple2<Double, Double> call(LabeledPoint p) { + return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); + } + }); + Double trainErr = + 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() { + @Override public Boolean call(Tuple2<Double, Double> pl) { + return !pl._1().equals(pl._2()); + } + }).count() / data.count(); + System.out.println("Training error: " + trainErr); + System.out.println("Learned classification tree model:\n" + model); + } else if (algo.equals("Regression")) { + // Train a GradientBoosting model for regression. 
+ final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD<Double, Double> predictionAndLabel = + data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { + @Override public Tuple2<Double, Double> call(LabeledPoint p) { + return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); + } + }); + Double trainMSE = + predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() { + @Override public Double call(Tuple2<Double, Double> pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2<Double, Double, Double>() { + @Override public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Training Mean Squared Error: " + trainMSE); + System.out.println("Learned regression tree model:\n" + model); + } else { + usage(); + } + + sc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 49751a30491d..63f02cf7b98b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -154,20 +154,30 @@ object DecisionTreeRunner { } } - def run(params: Params) { - - val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") - val sc = new SparkContext(conf) - - println(s"DecisionTreeRunner with parameters:\n$params") - + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset, number of classes), + * where the number of classes is inferred from data (and set to 0 for Regression) + */ + private[mllib] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: Algo, + fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = { // Load training data and cache it. - val origExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() + val origExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache() } // For classification, re-index classes if needed. - val (examples, classIndexMap, numClasses) = params.algo match { + val (examples, classIndexMap, numClasses) = algo match { case Classification => { // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() @@ -205,14 +215,14 @@ object DecisionTreeRunner { } // Create training, test sets. - val splits = if (params.testInput != "") { + val splits = if (testInput != "") { // Load testInput. 
val numFeatures = examples.take(1)(0).features.size - val origTestExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures) + val origTestExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures) } - params.algo match { + algo match { case Classification => { // classCounts: class --> # examples in class val testExamples = { @@ -229,17 +239,31 @@ object DecisionTreeRunner { } } else { // Split input into training, test. - examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + examples.randomSplit(Array(1.0 - fracTest, fracTest)) } val training = splits(0).cache() val test = splits(1).cache() + val numTraining = training.count() val numTest = test.count() - println(s"numTraining = $numTraining, numTest = $numTest.") examples.unpersist(blocking = false) + (training, test, numClasses) + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") + val sc = new SparkContext(conf) + + println(s"DecisionTreeRunner with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat, + params.testInput, params.algo, params.fracTest) + val impurityCalculator = params.impurity match { case Gini => impurity.Gini case Entropy => impurity.Entropy @@ -338,7 +362,9 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. */ - private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + private[mllib] def meanSquaredError( + tree: WeightedEnsembleModel, + data: RDD[LabeledPoint]): Double = { data.map { y => val err = tree.predict(y.features) - y.label err * err diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala new file mode 100644 index 000000000000..9b6db01448be --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.tree.GradientBoosting +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} +import org.apache.spark.util.Utils + +/** + * An example runner for Gradient Boosting using decision trees as weak learners. 
Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. + */ +object GradientBoostedTrees { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + numIterations: Int = 10, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GradientBoostedTrees") { + head("GradientBoostedTrees: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("numIterations") + .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}") + .action((x, c) => c.copy(numIterations = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("<dataFormat>") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("<input>") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params") + val sc = new SparkContext(conf) + + println(s"GradientBoostedTrees with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) + + val boostingStrategy = BoostingStrategy.defaultParams(params.algo) + boostingStrategy.numClassesForClassification = numClasses + boostingStrategy.numIterations = params.numIterations + boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth + + val randomSeed = Utils.random.nextInt() + if (params.algo == "Classification") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainClassifier(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. 
+ } + val trainAccuracy = + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) + .precision + println(s"Train accuracy = $trainAccuracy") + val testAccuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $testAccuracy") + } else if (params.algo == "Regression") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainRegressor(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainMSE = DecisionTreeRunner.meanSquaredError(model, training) + println(s"Train mean squared error = $trainMSE") + val testMSE = DecisionTreeRunner.meanSquaredError(model, test) + println(s"Test mean squared error = $testMSE") + } + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala index 1a847201ce15..f729344a682e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala @@ -17,30 +17,49 @@ package org.apache.spark.mllib.tree -import scala.collection.JavaConverters._ - +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy} -import org.apache.spark.Logging -import org.apache.spark.mllib.tree.impl.TimeTracker -import org.apache.spark.mllib.tree.loss.Losses -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum +import org.apache.spark.mllib.tree.impl.TimeTracker +import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * :: Experimental :: - * A class that implements gradient boosting for regression and binary classification problems. + * A class that implements Stochastic Gradient Boosting + * for regression and binary classification problems. + * + * The implementation is based upon: + * J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes: + * - This currently can be run with several loss functions. However, only SquaredError is + * fully supported. Specifically, the loss function should be used to compute the gradient + * (to re-label training instances on each iteration) and to weight weak hypotheses. + * Currently, gradients are computed correctly for the available loss functions, + * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError. + * Running with those losses will likely behave reasonably, but lacks the same guarantees. 
+ * * @param boostingStrategy Parameters for the gradient boosting algorithm */ @Experimental class GradientBoosting ( private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { + boostingStrategy.weakLearnerParams.algo = Regression + boostingStrategy.weakLearnerParams.impurity = impurity.Variance + + // Ensure values for weak learner are the same as what is provided to the boosting algorithm. + boostingStrategy.weakLearnerParams.numClassesForClassification = + boostingStrategy.numClassesForClassification + + boostingStrategy.assertValid() + /** * Method to train a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. @@ -51,6 +70,7 @@ class GradientBoosting ( algo match { case Regression => GradientBoosting.boost(input, boostingStrategy) case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoosting.boost(remappedInput, boostingStrategy) case _ => @@ -118,120 +138,32 @@ object GradientBoosting extends Logging { } /** - * Method to train a gradient boosting binary classification model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. - * @param loss Loss function used for minimization during gradient boosting. - * @param learningRate Learning rate for shrinking the contribution of each estimator. The - * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is - * supported.) - * @return WeightedEnsembleModel that can be used for prediction + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]] */ - def trainClassifier( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int, Int], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - val lossType = Losses.fromString(loss) - val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType, - learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, - weakLearnerParams) - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Method to train a gradient boosting regression model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param numEstimators Number of estimators used in boosting stages. 
In other words, - * number of boosting iterations performed. - * @param loss Loss function used for minimization during gradient boosting. - * @param learningRate Learning rate for shrinking the contribution of each estimator. The - * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is - * supported.) - * @return WeightedEnsembleModel that can be used for prediction - */ - def trainRegressor( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int, Int], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - val lossType = Losses.fromString(loss) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType, - learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, - weakLearnerParams) - new GradientBoosting(boostingStrategy).train(input) + def train( + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + train(input.rdd, boostingStrategy) } /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] */ def trainClassifier( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate, - numClassesForClassification, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - weakLearnerParams) + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainClassifier(input.rdd, boostingStrategy) } /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] */ def trainRegressor( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate, - numClassesForClassification, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - weakLearnerParams) + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainRegressor(input.rdd, boostingStrategy) } - /** * Internal method for performing regression using trees as base learners. 
* @param input training dataset @@ -247,15 +179,17 @@ object GradientBoosting extends Logging { timer.start("init") // Initialize gradient boosting parameters - val numEstimators = boostingStrategy.numEstimators - val baseLearners = new Array[DecisionTreeModel](numEstimators) - val baseLearnerWeights = new Array[Double](numEstimators) + val numIterations = boostingStrategy.numIterations + val baseLearners = new Array[DecisionTreeModel](numIterations) + val baseLearnerWeights = new Array[Double](numIterations) val loss = boostingStrategy.loss val learningRate = boostingStrategy.learningRate val strategy = boostingStrategy.weakLearnerParams // Cache input - input.persist(StorageLevel.MEMORY_AND_DISK) + if (input.getStorageLevel == StorageLevel.NONE) { + input.persist(StorageLevel.MEMORY_AND_DISK) + } timer.stop("init") @@ -264,7 +198,7 @@ object GradientBoosting extends Logging { logDebug("##########") var data = input - // 1. Initialize tree + // Initialize tree timer.start("building tree 0") val firstTreeModel = new DecisionTree(strategy).train(data) baseLearners(0) = firstTreeModel @@ -280,7 +214,7 @@ object GradientBoosting extends Logging { point.features)) var m = 1 - while (m < numEstimators) { + while (m < numIterations) { timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) @@ -289,6 +223,9 @@ object GradientBoosting extends Logging { timer.stop(s"building tree $m") // Create partial model baseLearners(m) = model + // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. + // Technically, the weight should be optimized for the particular loss. + // However, the behavior should be reasonable, though not optimal. baseLearnerWeights(m) = learningRate // Note: A model of type regression is used since we require raw prediction val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), @@ -305,8 +242,6 @@ object GradientBoosting extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - - // 3. Output classifier new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 501d9ff9ea9b..abbda040bd52 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -21,7 +21,6 @@ import scala.beans.BeanProperty import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.{Gini, Variance} import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} /** @@ -30,46 +29,58 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * @param algo Learning goal. Supported: * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. + * @param numIterations Number of iterations of boosting. In other words, the number of + * weak hypotheses used in the final model. * @param loss Loss function used for minimization during gradient boosting. 
* @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. * @param numClassesForClassification Number of classes for classification. * (Ignored for regression.) + * This setting overrides any setting in [[weakLearnerParams]]. * Default value is 2 (binary classification). - * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are * supported. */ @Experimental case class BoostingStrategy( // Required boosting parameters - algo: Algo, - @BeanProperty var numEstimators: Int, + @BeanProperty var algo: Algo, + @BeanProperty var numIterations: Int, @BeanProperty var loss: Loss, // Optional boosting parameters @BeanProperty var learningRate: Double = 0.1, - @BeanProperty var subsamplingRate: Double = 1.0, @BeanProperty var numClassesForClassification: Int = 2, - @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), @BeanProperty var weakLearnerParams: Strategy) extends Serializable { - require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " + - s"$learningRate.") - require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " + - s"$learningRate.") - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. - weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo weakLearnerParams.numClassesForClassification = numClassesForClassification - weakLearnerParams.subsamplingRate = subsamplingRate + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Check validity of parameters. + * Throws exception if invalid. + */ + private[tree] def assertValid(): Unit = { + algo match { + case Classification => + require(numClassesForClassification == 2) + case Regression => + // nothing + case _ => + throw new IllegalArgumentException( + s"BoostingStrategy given invalid algo parameter: $algo." + + s" Valid settings are: Classification, Regression.") + } + require(learningRate > 0 && learningRate <= 1, + "Learning rate should be in range (0, 1]. 
Provided learning rate is " + s"$learningRate.") + } } @Experimental @@ -82,28 +93,17 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: Algo): BoostingStrategy = { - val treeStrategy = defaultWeakLearnerParams(algo) + def defaultParams(algo: String): BoostingStrategy = { + val treeStrategy = Strategy.defaultStrategy("Regression") + treeStrategy.maxDepth = 3 algo match { - case Classification => - new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy) - case Regression => - new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy) + case "Classification" => + new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy) + case "Regression" => + new BoostingStrategy(Algo.withName(algo), 100, SquaredError, + weakLearnerParams = treeStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by the boosting.") } } - - /** - * Returns default configuration for the weak learner (decision tree) algorithm - * @param algo Learning goal. Supported: - * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], - * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] - * @return Configuration for weak learner - */ - def defaultWeakLearnerParams(algo: Algo): Strategy = { - // Note: Regression tree used even for classification for GBT. - new Strategy(Regression, Variance, 3) - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d09295c507d6..b5b1f82177ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -70,7 +70,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ */ @Experimental class Strategy ( - val algo: Algo, + @BeanProperty var algo: Algo, @BeanProperty var impurity: Impurity, @BeanProperty var maxDepth: Int, @BeanProperty var numClassesForClassification: Int = 2, @@ -85,17 +85,9 @@ class Strategy ( @BeanProperty var checkpointDir: Option[String] = None, @BeanProperty var checkpointInterval: Int = 10) extends Serializable { - if (algo == Classification) { - require(numClassesForClassification >= 2) - } - require(minInstancesPerNode >= 1, - s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") - require(maxMemoryInMB <= 10240, - s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") - - val isMulticlassClassification = + def isMulticlassClassification = algo == Classification && numClassesForClassification > 2 - val isMulticlassWithCategoricalFeatures + def isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** @@ -112,6 +104,23 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Sets categoricalFeaturesInfo using a Java Map. 
+ */ + def setCategoricalFeaturesInfo( + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { + setCategoricalFeaturesInfo( + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + /** * Check validity of parameters. * Throws exception if invalid. @@ -143,6 +152,26 @@ class Strategy ( s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + s" feature $feature has $arity categories. The number of categories should be >= 2.") } + require(minInstancesPerNode >= 1, + s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + require(maxMemoryInMB <= 10240, + s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") } +} + +@Experimental +object Strategy { + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo "Classification" or "Regression" + */ + def defaultStrategy(algo: String): Strategy = algo match { + case "Classification" => + new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, + numClassesForClassification = 2) + case "Regression" => + new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, + numClassesForClassification = 0) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala index 970fff82215e..99a02eda60ba 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala @@ -22,9 +22,8 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} -import org.apache.spark.mllib.tree.impurity.{Variance, Gini} +import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss} -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.util.LocalSparkContext @@ -34,9 +33,8 @@ import org.apache.spark.mllib.util.LocalSparkContext class GradientBoostingSuite extends FunSuite with LocalSparkContext { test("Regression with continuous features: SquaredError") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -48,11 +46,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, + learningRate, 1, treeStrategy) val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) @@ -63,9 +61,8 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { } test("Regression with continuous 
features: Absolute Error") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -77,11 +74,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, + learningRate, numClassesForClassification = 2, treeStrategy) val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) @@ -91,11 +88,9 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { } } - test("Binary classification with continuous features: Log Loss") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -107,11 +102,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss, + learningRate, numClassesForClassification = 2, treeStrategy) val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) EnsembleTestHelper.validateClassifier(gbt, arr, 0.9) @@ -126,7 +121,6 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { object GradientBoostingSuite { // Combinations for estimators, learning rates and subsamplingRate - val testCombinations - = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) + val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) } From 236434033fe452e70dbd0236935a49693712e130 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Tue, 4 Nov 2014 16:15:38 -0800 Subject: [PATCH 020/652] [SPARK-2938] Support SASL authentication in NettyBlockTransferService Also lays the groundwork for supporting it inside the external shuffle service. Author: Aaron Davidson Closes #3087 from aarondav/sasl and squashes the following commits: 3481718 [Aaron Davidson] Delete rogue println 44f8410 [Aaron Davidson] Delete documentation - muahaha! 
eb9f065 [Aaron Davidson] Improve documentation and add end-to-end test at Spark-level a6b95f1 [Aaron Davidson] Address comments 785bbde [Aaron Davidson] Cleanup 79973cb [Aaron Davidson] Remove unused file 151b3c5 [Aaron Davidson] Add docs, timeout config, better failure handling f6177d7 [Aaron Davidson] Cleanup SASL state upon connection termination 7b42adb [Aaron Davidson] Add unit tests 8191bcb [Aaron Davidson] [SPARK-2938] Support SASL authentication in NettyBlockTransferService --- .../org/apache/spark/SecurityManager.scala | 23 ++- .../scala/org/apache/spark/SparkConf.scala | 6 + .../scala/org/apache/spark/SparkContext.scala | 2 + .../scala/org/apache/spark/SparkEnv.scala | 3 +- .../org/apache/spark/SparkSaslClient.scala | 147 --------------- .../org/apache/spark/SparkSaslServer.scala | 176 ------------------ .../org/apache/spark/executor/Executor.scala | 1 + .../netty/NettyBlockTransferService.scala | 28 ++- .../apache/spark/network/nio/Connection.scala | 5 +- .../spark/network/nio/ConnectionManager.scala | 7 +- .../apache/spark/storage/BlockManager.scala | 45 +++-- .../NettyBlockTransferSecuritySuite.scala | 161 ++++++++++++++++ .../network/nio/ConnectionManagerSuite.scala | 6 +- .../BlockManagerReplicationSuite.scala | 2 + .../spark/storage/BlockManagerSuite.scala | 4 +- docs/security.md | 1 - .../spark/network/TransportContext.java | 15 +- .../spark/network/client/TransportClient.java | 11 +- .../client/TransportClientBootstrap.java | 32 ++++ .../client/TransportClientFactory.java | 64 +++++-- .../spark/network/server/NoOpRpcHandler.java | 2 +- .../spark/network/server/RpcHandler.java | 19 +- .../server/TransportRequestHandler.java | 1 + .../spark/network/util/TransportConf.java | 3 + .../network/sasl/SaslClientBootstrap.java | 74 ++++++++ .../spark/network/sasl/SaslMessage.java | 74 ++++++++ .../spark/network/sasl/SaslRpcHandler.java | 97 ++++++++++ .../spark/network/sasl/SecretKeyHolder.java | 35 ++++ .../spark/network/sasl/SparkSaslClient.java | 138 ++++++++++++++ .../spark/network/sasl/SparkSaslServer.java | 170 +++++++++++++++++ .../shuffle/ExternalShuffleBlockHandler.java | 2 +- .../shuffle/ExternalShuffleClient.java | 15 +- .../spark/network/shuffle/ShuffleClient.java | 11 +- .../network/sasl/SaslIntegrationSuite.java | 172 +++++++++++++++++ .../spark/network/sasl/SparkSaslSuite.java | 89 +++++++++ .../ExternalShuffleIntegrationSuite.java | 7 +- .../streaming/ReceivedBlockHandlerSuite.scala | 1 + 37 files changed, 1257 insertions(+), 392 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/SparkSaslClient.scala delete mode 100644 core/src/main/scala/org/apache/spark/SparkSaslServer.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala create mode 100644 network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java create mode 100644 
network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 0e0f1a7b2377..dee935ffad51 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication} import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.network.sasl.SecretKeyHolder /** * Spark class responsible for security. @@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * Authenticator installed in the SecurityManager to how it does the authentication * and in this case gets the user name and password from the request. * - * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously + * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously * exchange messages. For this we use the Java SASL * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 * as the authentication mechanism. This means the shared secret is not passed @@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * of protection they want. If we support those, the messages will also have to * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. * - * Since the connectionManager does asynchronous messages passing, the SASL + * Since the NioBlockTransferService does asynchronous messages passing, the SASL * authentication is a bit more complex. A ConnectionManager can be both a client * and a Server, so for a particular connection is has to determine what to do. * A ConnectionId was added to be able to track connections and is used to @@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil * and waits for the response from the server and does the handshake before sending * the real message. * + * The NettyBlockTransferService ensures that SASL authentication is performed + * synchronously prior to any other communication on a connection. This is done in + * SaslClientBootstrap on the client side and SaslRpcHandler on the server side. + * * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters * can be used. Yarn requires a specific AmIpFilter be installed for security to work * properly. For non-Yarn deployments, users can write a filter to go through a @@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * can take place. 
*/ -private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { +private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder { // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" @@ -337,4 +342,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { * @return the secret key as a String if authentication is enabled, otherwise returns null */ def getSecretKey(): String = secretKey + + override def getSaslUser(appId: String): String = { + val myAppId = sparkConf.getAppId + require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") + getSaslUser() + } + + override def getSecretKey(appId: String): String = { + val myAppId = sparkConf.getAppId + require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") + getSecretKey() + } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ad0a9017afea..4c6c86c7bad7 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -217,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { */ getAll.filter { case (k, _) => isAkkaConf(k) } + /** + * Returns the Spark application id, valid in the Driver after TaskScheduler registration and + * from the start in the Executor. + */ + def getAppId: String = get("spark.app.id") + /** Does the configuration contain a given parameter? */ def contains(key: String): Boolean = settings.contains(key) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8b4db783979e..d65027d18e2d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -313,6 +313,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { val applicationId: String = taskScheduler.applicationId() conf.set("spark.app.id", applicationId) + env.blockManager.initialize(applicationId) + val metricsSystem = env.metricsSystem // The metrics system for Driver need to be set spark.app.id to app ID. diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index e2f13accdfab..45e9d7f243e9 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -276,7 +276,7 @@ object SparkEnv extends Logging { val blockTransferService = conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { case "netty" => - new NettyBlockTransferService(conf) + new NettyBlockTransferService(conf, securityManager) case "nio" => new NioBlockTransferService(conf, securityManager) } @@ -285,6 +285,7 @@ object SparkEnv extends Logging { "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + // NB: blockManager is not valid until initialize() is called later. 
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala deleted file mode 100644 index a954fcc0c31f..000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.RealmCallback -import javax.security.sasl.RealmChoiceCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslClient -import javax.security.sasl.SaslException - -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 - -/** - * Implements SASL Client logic for Spark - */ -private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging { - - /** - * Used to respond to server's counterpart, SaslServer with SASL tokens - * represented as byte arrays. - * - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), - null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslClientCallbackHandler(securityMgr)) - - /** - * Used to initiate SASL handshake with server. - * @return response to challenge if needed - */ - def firstToken(): Array[Byte] = { - synchronized { - val saslToken: Array[Byte] = - if (saslClient != null && saslClient.hasInitialResponse()) { - logDebug("has initial response") - saslClient.evaluateChallenge(new Array[Byte](0)) - } else { - new Array[Byte](0) - } - saslToken - } - } - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslClient != null) saslClient.isComplete() else false - } - } - - /** - * Respond to server's SASL token. 
- * @param saslTokenMessage contains server's SASL token - * @return client's response SASL token - */ - def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { - synchronized { - if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslClient might be using. - */ - def dispose() { - synchronized { - if (saslClient != null) { - try { - saslClient.dispose() - } catch { - case e: SaslException => // ignored - } finally { - saslClient = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * that works with share secrets. - */ - private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends - CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - private val secretKey = securityMgr.getSecretKey() - private val userPassword: Array[Char] = SparkSaslServer.encodePassword( - if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8)) - - /** - * Implementation used to respond to SASL request from the server. - * - * @param callbacks objects that indicate what credential information the - * server's SaslServer requires from the client. - */ - override def handle(callbacks: Array[Callback]) { - logDebug("in the sasl client callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL client callback: setting username: " + userName) - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL client callback: setting userPassword") - pc.setPassword(userPassword) - } - case rc: RealmCallback => { - logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case cb: RealmChoiceCallback => {} - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala deleted file mode 100644 index 7c2afb364661..000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.AuthorizeCallback -import javax.security.sasl.RealmCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslException -import javax.security.sasl.SaslServer -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 -import org.apache.commons.net.util.Base64 - -/** - * Encapsulates SASL server logic - */ -private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging { - - /** - * Actual SASL work done by this object from javax.security.sasl. - */ - private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null, - SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslDigestCallbackHandler(securityMgr)) - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslServer != null) saslServer.isComplete() else false - } - } - - /** - * Used to respond to server SASL tokens. - * @param token Server's SASL token - * @return response to send back to the server. - */ - def response(token: Array[Byte]): Array[Byte] = { - synchronized { - if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslServer might be using. - */ - def dispose() { - synchronized { - if (saslServer != null) { - try { - saslServer.dispose() - } catch { - case e: SaslException => // ignore - } finally { - saslServer = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * for SASL DIGEST-MD5 mechanism - */ - private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) - extends CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - - override def handle(callbacks: Array[Callback]) { - logDebug("In the sasl server callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL server callback: setting username") - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL server callback: setting userPassword") - val password: Array[Char] = - SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8)) - pc.setPassword(password) - } - case rc: RealmCallback => { - logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case ac: AuthorizeCallback => { - val authid = ac.getAuthenticationID() - val authzid = ac.getAuthorizationID() - if (authid.equals(authzid)) { - logDebug("set auth to true") - ac.setAuthorized(true) - } else { - logDebug("set auth to false") - ac.setAuthorized(false) - } - if (ac.isAuthorized()) { - logDebug("sasl server is authorized") - ac.setAuthorizedID(authzid) - } - } - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") - } - } - } -} - -private[spark] object SparkSaslServer { - - /** - * This is passed as the server name when creating the sasl client/server. 
- * This could be changed to be configurable in the future. - */ - val SASL_DEFAULT_REALM = "default" - - /** - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - val DIGEST = "DIGEST-MD5" - - /** - * The quality of protection is just "auth". This means that we are doing - * authentication only, we are not supporting integrity or privacy protection of the - * communication channel after authentication. This could be changed to be configurable - * in the future. - */ - val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true") - - /** - * Encode a byte[] identifier as a Base64-encoded string. - * - * @param identifier identifier to encode - * @return Base64-encoded string - */ - def encodeIdentifier(identifier: Array[Byte]): String = { - new String(Base64.encodeBase64(identifier), UTF_8) - } - - /** - * Encode a password as a base64-encoded char[] array. - * @param password as a byte array. - * @return password as a char array. - */ - def encodePassword(password: Array[Byte]): Array[Char] = { - new String(Base64.encodeBase64(password), UTF_8).toCharArray() - } -} - diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e24a15f015e1..7dd5265891c3 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -86,6 +86,7 @@ private[spark] class Executor( conf, executorId, slaveHostname, port, isLocal, actorSystem) SparkEnv.set(_env) _env.metricsSystem.registerSource(executorSource) + _env.blockManager.initialize(conf.getAppId) _env } else { SparkEnv.get diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 1c4327cf13b5..0d1fc81d2a16 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,13 +17,15 @@ package org.apache.spark.network.netty +import scala.collection.JavaConversions._ import scala.concurrent.{Future, Promise} -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory} +import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} +import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.serializer.JavaSerializer @@ -33,18 +35,30 @@ import org.apache.spark.util.Utils /** * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ -class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { +class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager) + extends BlockTransferService { + // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
- val serializer = new JavaSerializer(conf) + private val serializer = new JavaSerializer(conf) + private val authEnabled = securityManager.isAuthenticationEnabled() + private val transportConf = SparkTransportConf.fromSparkConf(conf) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ private[this] var clientFactory: TransportClientFactory = _ override def init(blockDataManager: BlockDataManager): Unit = { - val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) - transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler) - clientFactory = transportContext.createClientFactory() + val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { + val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + if (!authEnabled) { + (nettyRpcHandler, None) + } else { + (new SaslRpcHandler(nettyRpcHandler, securityManager), + Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager))) + } + } + transportContext = new TransportContext(transportConf, rpcHandler) + clientFactory = transportContext.createClientFactory(bootstrap.toList) server = transportContext.createServer() logInfo("Server created on " + server.getPort) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 4f6f5e235811..c2d9578be7eb 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -23,12 +23,13 @@ import java.nio.channels._ import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList -import org.apache.spark._ - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal +import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} + private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 8408b75bb4d6..f198aa8564a5 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,7 @@ import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} import org.apache.spark.util.Utils import scala.util.Try @@ -600,7 +601,7 @@ private[nio] class ConnectionManager( } else { var replyToken : Array[Byte] = null try { - replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken) + replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId @@ -634,7 +635,7 @@ private[nio] class ConnectionManager( connection.synchronized { if (connection.sparkSaslServer == null) { logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(securityManager) + connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager) } } replyToken = 
connection.sparkSaslServer.response(securityMsg.getToken) @@ -778,7 +779,7 @@ private[nio] class ConnectionManager( if (!conn.isSaslComplete()) { conn.synchronized { if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(securityManager) + conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager) var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 5f5dd0dc1c63..655d16c65c8b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -57,6 +57,12 @@ private[spark] class BlockResult( inputMetrics.bytesRead = bytes } +/** + * Manager running on every node (driver and executors) which provides interfaces for putting and + * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). + * + * Note that #initialize() must be called before the BlockManager is usable. + */ private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, @@ -69,8 +75,6 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService) extends BlockDataManager with Logging { - blockTransferService.init(this) - val diskBlockManager = new DiskBlockManager(this, conf) private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -102,22 +106,16 @@ private[spark] class BlockManager( + " switch to sort-based shuffle.") } - val blockManagerId = BlockManagerId( - executorId, blockTransferService.hostName, blockTransferService.port) + var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external // service, or just our own Executor's BlockManager. - private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) { - BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) - } else { - blockManagerId - } + private[spark] var shuffleServerId: BlockManagerId = _ // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTranserService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val appId = conf.get("spark.app.id", "unknown-app-id") - new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId) + new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf)) } else { blockTransferService } @@ -150,8 +148,6 @@ private[spark] class BlockManager( private val peerFetchLock = new Object private var lastPeerFetchTime = 0L - initialize() - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -176,10 +172,27 @@ private[spark] class BlockManager( } /** - * Initialize the BlockManager. Register to the BlockManagerMaster, and start the - * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured. + * Initializes the BlockManager with the given appId. 
This is not performed in the constructor as + * the appId may not be known at BlockManager instantiation time (in particular for the driver, + * where it is only learned after registration with the TaskScheduler). + * + * This method initializes the BlockTransferService and ShuffleClient, registers with the + * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * service if configured. */ - private def initialize(): Unit = { + def initialize(appId: String): Unit = { + blockTransferService.init(this) + shuffleClient.init(appId) + + blockManagerId = BlockManagerId( + executorId, blockTransferService.hostName, blockTransferService.port) + + shuffleServerId = if (externalShuffleServiceEnabled) { + BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) + } else { + blockManagerId + } + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) // Register Executors' configuration with the local shuffle service, if one should exist. diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala new file mode 100644 index 000000000000..bed0ed9d713d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty + +import java.nio._ +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise} +import scala.util.{Failure, Success, Try} + +import org.apache.commons.io.IOUtils +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.storage.{BlockId, ShuffleBlockId} +import org.apache.spark.{SecurityManager, SparkConf} +import org.mockito.Mockito._ +import org.scalatest.mock.MockitoSugar +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers} + +class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { + test("security default off") { + testConnection(new SparkConf, new SparkConf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } + } + + test("security on same password") { + val conf = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + testConnection(conf, conf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } + } + + test("security on mismatch password") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate.secret", "bad") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("Mismatched response") + } + } + + test("security mismatch auth off on server") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate", "false") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => // any funny error may occur, sever will interpret SASL token as RPC + } + } + + test("security mismatch auth off on client") { + val conf0 = new SparkConf() + .set("spark.authenticate", "false") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate", "true") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("Expected SaslMessage") + } + } + + test("security mismatch app ids") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.app.id", "other-id") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("SASL appId app-id did not match") + } + } + + /** + * Creates two servers with different configurations and sees if they can talk. + * Returns Success() if they can transfer a block, and Failure() if the block transfer was failed + * properly. We will throw an out-of-band exception if something other than that goes wrong. + */ + private def testConnection(conf0: SparkConf, conf1: SparkConf): Try[Unit] = { + val blockManager = mock[BlockDataManager] + val blockId = ShuffleBlockId(0, 1, 2) + val blockString = "Hello, world!" 
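The fetchBlock helper further down in this suite adapts the callback-based fetchBlocks API into a blocking call by completing a Promise from the BlockFetchingListener. The same pattern in isolation, using stand-in types rather than Spark's:

~~~
import scala.concurrent.duration._
import scala.concurrent.{Await, Promise}
import scala.util.Try

// Stand-in for a callback-style API such as BlockFetchingListener.
trait FetchListener {
  def onSuccess(data: String): Unit
  def onFailure(t: Throwable): Unit
}

def fetchAsync(listener: FetchListener): Unit = listener.onSuccess("Hello, world!")

// Blocks the caller until one of the callbacks fires (or the timeout expires).
def fetchSync(): Try[String] = {
  val promise = Promise[String]()
  fetchAsync(new FetchListener {
    override def onSuccess(data: String): Unit = promise.success(data)
    override def onFailure(t: Throwable): Unit = promise.failure(t)
  })
  Await.ready(promise.future, 1.second)
  promise.future.value.get
}
~~~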
+ val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes)) + when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) + + val securityManager0 = new SecurityManager(conf0) + val exec0 = new NettyBlockTransferService(conf0, securityManager0) + exec0.init(blockManager) + + val securityManager1 = new SecurityManager(conf1) + val exec1 = new NettyBlockTransferService(conf1, securityManager1) + exec1.init(blockManager) + + val result = fetchBlock(exec0, exec1, "1", blockId) match { + case Success(buf) => + IOUtils.toString(buf.createInputStream()) should equal(blockString) + buf.release() + Success() + case Failure(t) => + Failure(t) + } + exec0.close() + exec1.close() + result + } + + /** Synchronously fetches a single block, acting as the given executor fetching from another. */ + private def fetchBlock( + self: BlockTransferService, + from: BlockTransferService, + execId: String, + blockId: BlockId): Try[ManagedBuffer] = { + + val promise = Promise[ManagedBuffer]() + + self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString), + new BlockFetchingListener { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { + promise.failure(exception) + } + + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + promise.success(data.retain()) + } + }) + + Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS)) + promise.future.value.get + } +} + diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index b70734dfe37c..716f875d30b8 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -60,6 +60,7 @@ class ConnectionManagerSuite extends FunSuite { val conf = new SparkConf conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") + conf.set("spark.app.id", "app-id") val securityManager = new SecurityManager(conf) val manager = new ConnectionManager(0, conf, securityManager) var numReceivedMessages = 0 @@ -95,6 +96,7 @@ class ConnectionManagerSuite extends FunSuite { test("security mismatch password") { val conf = new SparkConf conf.set("spark.authenticate", "true") + conf.set("spark.app.id", "app-id") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val manager = new ConnectionManager(0, conf, securityManager) @@ -105,9 +107,7 @@ class ConnectionManagerSuite extends FunSuite { None }) - val badconf = new SparkConf - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "bad") + val badconf = conf.clone.set("spark.authenticate.secret", "bad") val badsecurityManager = new SecurityManager(badconf) val managerServer = new ConnectionManager(0, badconf, badsecurityManager) var numReceivedServerMessages = 0 diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c6d710559209..1461fa69db90 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -63,6 +63,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val transfer = new NioBlockTransferService(conf, securityMgr) val store = 
new BlockManager(name, actorSystem, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer) + store.initialize("app-id") allStores += store store } @@ -263,6 +264,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd when(failableTransfer.port).thenReturn(1000) val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer) + failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 715b740b857b..0782876c8e3c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -73,8 +73,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer) + manager.initialize("app-id") + manager } before { diff --git a/docs/security.md b/docs/security.md index ec0523184d66..1e206a139fb7 100644 --- a/docs/security.md +++ b/docs/security.md @@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. Authentication can * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. -* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.* ## Web UI diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index a271841e4e56..5bc6e5a2418a 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -17,12 +17,16 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.MessageDecoder; @@ -64,8 +68,17 @@ public TransportContext(TransportConf conf, RpcHandler rpcHandler) { this.decoder = new MessageDecoder(); } + /** + * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning + * a new Client. 
Bootstraps will be executed synchronously, and must run successfully in order + * to create a Client. + */ + public TransportClientFactory createClientFactory(List bootstraps) { + return new TransportClientFactory(this, bootstraps); + } + public TransportClientFactory createClientFactory() { - return new TransportClientFactory(this); + return createClientFactory(Lists.newArrayList()); } /** Create a server which will attempt to bind to a specific port. */ diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 01c143fff423..a08cee02dd57 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -19,10 +19,9 @@ import java.io.Closeable; import java.util.UUID; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; @@ -186,4 +185,12 @@ public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("remoteAdress", channel.remoteAddress()) + .add("isActive", isActive()) + .toString(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java new file mode 100644 index 000000000000..65e8020e3412 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * A bootstrap which is executed on a TransportClient before it is returned to the user. + * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- + * connection basis. + * + * Since connections (and TransportClients) are reused as much as possible, it is generally + * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with + * the JVM itself. + */ +public interface TransportClientBootstrap { + /** Performs the bootstrapping operation, throwing an exception on failure. 
*/ + public void doBootstrap(TransportClient client) throws RuntimeException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 0b4a1d828640..1723fed30725 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -21,10 +21,14 @@ import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.Lists; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.Channel; @@ -40,6 +44,7 @@ import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -47,22 +52,29 @@ * Factory for creating {@link TransportClient}s by using createClient. * * The factory maintains a connection pool to other hosts and should return the same - * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for - * all {@link TransportClient}s. + * TransportClient for the same remote host. It also shares a single worker thread pool for + * all TransportClients. + * + * TransportClients will be reused whenever possible. Prior to completing the creation of a new + * TransportClient, all given {@link TransportClientBootstrap}s will be run. */ public class TransportClientFactory implements Closeable { private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); private final TransportContext context; private final TransportConf conf; + private final List clientBootstraps; private final ConcurrentHashMap connectionPool; private final Class socketChannelClass; private EventLoopGroup workerGroup; - public TransportClientFactory(TransportContext context) { - this.context = context; + public TransportClientFactory( + TransportContext context, + List clientBootstraps) { + this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); + this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); this.connectionPool = new ConcurrentHashMap(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); @@ -72,9 +84,12 @@ public TransportClientFactory(TransportContext context) { } /** - * Create a new BlockFetchingClient connecting to the given remote host / port. + * Create a new {@link TransportClient} connecting to the given remote host / port. This will + * reuse TransportClients if they are still active and are for the same remote address. Prior + * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s + * that are registered with this factory. * - * This blocks until a connection is successfully established. + * This blocks until a connection is successfully established and fully bootstrapped. * * Concurrency: This method is safe to call from multiple threads. 
*/ @@ -104,17 +119,18 @@ public TransportClient createClient(String remoteHost, int remotePort) { // Use pooled buffers to reduce temporary buffer allocation bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); - final AtomicReference client = new AtomicReference(); + final AtomicReference clientRef = new AtomicReference(); bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { TransportChannelHandler clientHandler = context.initializePipeline(ch); - client.set(clientHandler.getClient()); + clientRef.set(clientHandler.getClient()); } }); // Connect to the remote server + long preConnect = System.currentTimeMillis(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { throw new RuntimeException( @@ -123,15 +139,35 @@ public void initChannel(SocketChannel ch) { throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); } - // Successful connection -- in the event that two threads raced to create a client, we will + TransportClient client = clientRef.get(); + assert client != null : "Channel future completed successfully with null client"; + + // Execute any client bootstraps synchronously before marking the Client as successful. + long preBootstrap = System.currentTimeMillis(); + logger.debug("Connection to {} successful, running bootstraps...", address); + try { + for (TransportClientBootstrap clientBootstrap : clientBootstraps) { + clientBootstrap.doBootstrap(client); + } + } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala + long bootstrapTime = System.currentTimeMillis() - preBootstrap; + logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e); + client.close(); + throw Throwables.propagate(e); + } + long postBootstrap = System.currentTimeMillis(); + + // Successful connection & bootstrap -- in the event that two threads raced to create a client, // use the first one that was put into the connectionPool and close the one we made here. - assert client.get() != null : "Channel future completed successfully with null client"; - TransportClient oldClient = connectionPool.putIfAbsent(address, client.get()); + TransportClient oldClient = connectionPool.putIfAbsent(address, client); if (oldClient == null) { - return client.get(); + logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", + address, postBootstrap - preConnect, postBootstrap - preBootstrap); + return client; } else { - logger.debug("Two clients were created concurrently, second one will be disposed."); - client.get().close(); + logger.debug("Two clients were created concurrently after {} ms, second will be disposed.", + postBootstrap - preConnect); + client.close(); return oldClient; } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 5a3f003726fc..1502b7489e86 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -21,7 +21,7 @@ import org.apache.spark.network.client.TransportClient; /** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. 
*/ -public class NoOpRpcHandler implements RpcHandler { +public class NoOpRpcHandler extends RpcHandler { private final StreamManager streamManager; public NoOpRpcHandler() { diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 2369dc620394..2ba92a40f8b0 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,22 +23,33 @@ /** * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ -public interface RpcHandler { +public abstract class RpcHandler { /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * + * This method will not be called in parallel for a single TransportClient (i.e., channel). + * * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. + * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. * @param callback Callback which should be invoked exactly once upon success or failure of the * RPC. */ - void receive(TransportClient client, byte[] message, RpcResponseCallback callback); + public abstract void receive( + TransportClient client, + byte[] message, + RpcResponseCallback callback); /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. */ - StreamManager getStreamManager(); + public abstract StreamManager getStreamManager(); + + /** + * Invoked when the connection associated with the given client has been invalidated. + * No further requests will come from this client. + */ + public void connectionTerminated(TransportClient client) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 17fe9001b35c..1580180cc17e 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -86,6 +86,7 @@ public void channelUnregistered() { for (long streamId : streamIds) { streamManager.connectionTerminated(streamId); } + rpcHandler.connectionTerminated(reverseClient); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index a68f38e0e94c..823790dd3c66 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -55,4 +55,7 @@ public int connectionTimeoutMs() { /** Send buffer size (SO_SNDBUF). */ public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + + /** Timeout for a single round trip of SASL token exchange, in milliseconds. 
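Turning RpcHandler into an abstract class lets connectionTerminated default to a no-op, while SaslRpcHandler (added below) overrides it to dispose per-connection SASL state. A minimal concrete handler sketched in Scala; it assumes the OneForOneStreamManager from network/common (as used by NettyBlockRpcServer), and the class itself is hypothetical:

~~~
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}

// Hypothetical handler that echoes every RPC payload straight back to the caller.
class EchoRpcHandler extends RpcHandler {
  private val streamManager = new OneForOneStreamManager()

  override def receive(
      client: TransportClient,
      message: Array[Byte],
      callback: RpcResponseCallback): Unit = {
    callback.onSuccess(message)
  }

  override def getStreamManager(): StreamManager = streamManager

  // connectionTerminated(client) is deliberately not overridden: the default no-op body is
  // enough when there is no per-connection state to clean up.
}
~~~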
*/ + public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java new file mode 100644 index 000000000000..7bc91e375371 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The + * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. + */ +public class SaslClientBootstrap implements TransportClientBootstrap { + private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); + + private final TransportConf conf; + private final String appId; + private final SecretKeyHolder secretKeyHolder; + + public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.appId = appId; + this.secretKeyHolder = secretKeyHolder; + } + + /** + * Performs SASL authentication by sending a token, and then proceeding with the SASL + * challenge-response tokens until we either successfully authenticate or throw an exception + * due to mismatch. + */ + @Override + public void doBootstrap(TransportClient client) { + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); + try { + byte[] payload = saslClient.firstToken(); + + while (!saslClient.isComplete()) { + SaslMessage msg = new SaslMessage(appId, payload); + ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + msg.encode(buf); + + byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout()); + payload = saslClient.response(response); + } + } finally { + try { + // Once authentication is complete, the server will trust all remaining communication. 
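SaslClientBootstrap#doBootstrap drives the client half of the handshake over sendRpcSync, wrapping each token in a SaslMessage. The same token loop can be exercised without any transport by wiring the new SparkSaslClient and SparkSaslServer (added below) together directly; the holder, user name, secret, and app id are placeholders:

~~~
import org.apache.spark.network.sasl.{SecretKeyHolder, SparkSaslClient, SparkSaslServer}

val holder = new SecretKeyHolder {
  override def getSaslUser(appId: String): String = "sparkSaslUser"    // placeholder
  override def getSecretKey(appId: String): String = "shared-secret"   // placeholder
}

val client = new SparkSaslClient("app-id", holder)
val server = new SparkSaslServer("app-id", holder)

// Same shape as doBootstrap: keep trading tokens until the client reports completion.
var token = client.firstToken()
while (!client.isComplete()) {
  val challenge = server.response(token)   // what SaslRpcHandler would compute remotely
  token = client.response(challenge)
}
assert(server.isComplete())
client.dispose()
server.dispose()
~~~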
+ saslClient.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL client", e); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java new file mode 100644 index 000000000000..5b77e18c26bf --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import com.google.common.base.Charsets; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged + * with the given appId. This appId allows a single SaslRpcHandler to multiplex different + * applications which may be using different sets of credentials. + */ +class SaslMessage implements Encodable { + + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEA; + + public final String appId; + public final byte[] payload; + + public SaslMessage(String appId, byte[] payload) { + this.appId = appId; + this.payload = payload; + } + + @Override + public int encodedLength() { + // tag + appIdLength + appId + payloadLength + payload + return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + byte[] idBytes = appId.getBytes(Charsets.UTF_8); + buf.writeInt(idBytes.length); + buf.writeBytes(idBytes); + buf.writeInt(payload.length); + buf.writeBytes(payload); + } + + public static SaslMessage decode(ByteBuf buf) { + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected SaslMessage, received something else"); + } + + int idLength = buf.readInt(); + byte[] idBytes = new byte[idLength]; + buf.readBytes(idBytes); + + int payloadLength = buf.readInt(); + byte[] payload = new byte[payloadLength]; + buf.readBytes(payload); + + return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java new file mode 100644 index 000000000000..3777a18e33f7 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
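SaslMessage (just above) frames every handshake token as a tag byte, the UTF-8 appId with a length prefix, and the payload with a length prefix, which is what lets a single SaslRpcHandler multiplex applications. A short round trip over the same layout using Netty buffers; the id and payload values are placeholders:

~~~
import com.google.common.base.Charsets
import io.netty.buffer.Unpooled

val appId = "app-20141103-0001"                 // placeholder application id
val payload = Array[Byte](1, 2, 3)              // placeholder SASL token
val idBytes = appId.getBytes(Charsets.UTF_8)

// Encode: [tag][appId length][appId bytes][payload length][payload bytes]
val buf = Unpooled.buffer(1 + 4 + idBytes.length + 4 + payload.length)
buf.writeByte(0xEA)
buf.writeInt(idBytes.length)
buf.writeBytes(idBytes)
buf.writeInt(payload.length)
buf.writeBytes(payload)

// Decode: read the fields back in the same order, checking the tag byte first.
require(buf.readByte() == 0xEA.toByte, "Expected SaslMessage, received something else")
val decodedId = new Array[Byte](buf.readInt())
buf.readBytes(decodedId)
val decodedPayload = new Array[Byte](buf.readInt())
buf.readBytes(decodedPayload)
assert(new String(decodedId, Charsets.UTF_8) == appId)
~~~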
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.concurrent.ConcurrentMap; + +import com.google.common.base.Charsets; +import com.google.common.collect.Maps; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; + +/** + * RPC Handler which performs SASL authentication before delegating to a child RPC handler. + * The delegate will only receive messages if the given connection has been successfully + * authenticated. A connection may be authenticated at most once. + * + * Note that the authentication process consists of multiple challenge-response pairs, each of + * which are individual RPCs. + */ +public class SaslRpcHandler extends RpcHandler { + private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); + + /** RpcHandler we will delegate to for authenticated connections. */ + private final RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. */ + private final SecretKeyHolder secretKeyHolder; + + /** Maps each channel to its SASL authentication state. */ + private final ConcurrentMap channelAuthenticationMap; + + public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { + this.delegate = delegate; + this.secretKeyHolder = secretKeyHolder; + this.channelAuthenticationMap = Maps.newConcurrentMap(); + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + SparkSaslServer saslServer = channelAuthenticationMap.get(client); + if (saslServer != null && saslServer.isComplete()) { + // Authentication complete, delegate to base handler. + delegate.receive(client, message, callback); + return; + } + + SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + + if (saslServer == null) { + // First message in the handshake, setup the necessary state. 
+ saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder); + channelAuthenticationMap.put(client, saslServer); + } + + byte[] response = saslServer.response(saslMessage.payload); + if (saslServer.isComplete()) { + logger.debug("SASL authentication successful for channel {}", client); + } + callback.onSuccess(response); + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void connectionTerminated(TransportClient client) { + SparkSaslServer saslServer = channelAuthenticationMap.remove(client); + if (saslServer != null) { + saslServer.dispose(); + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java new file mode 100644 index 000000000000..81d576679468 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +/** + * Interface for getting a secret key associated with some application. + */ +public interface SecretKeyHolder { + /** + * Gets an appropriate SASL User for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL user. + */ + String getSaslUser(String appId); + + /** + * Gets an appropriate SASL secret key for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key. + */ + String getSecretKey(String appId); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java new file mode 100644 index 000000000000..72ba737b998b --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.RealmChoiceCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import java.io.IOException; + +import com.google.common.base.Throwables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.spark.network.sasl.SparkSaslServer.*; + +/** + * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. This client initializes the protocol via a + * firstToken, which is then followed by a set of challenges and responses. + */ +public class SparkSaslClient { + private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class); + + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslClient saslClient; + + public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, + SASL_PROPS, new ClientCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** Used to initiate SASL handshake with server. */ + public synchronized byte[] firstToken() { + if (saslClient != null && saslClient.hasInitialResponse()) { + try { + return saslClient.evaluateChallenge(new byte[0]); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } else { + return new byte[0]; + } + } + + /** Determines whether the authentication exchange has completed. */ + public synchronized boolean isComplete() { + return saslClient != null && saslClient.isComplete(); + } + + /** + * Respond to server's SASL token. + * @param token contains server's SASL token + * @return client's response SASL token + */ + public synchronized byte[] response(byte[] token) { + try { + return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. + */ + public synchronized void dispose() { + if (saslClient != null) { + try { + saslClient.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslClient = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. 
+ */ + private class ClientCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL client callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL client callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL client callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + logger.info("Realm callback"); + } else if (callback instanceof RealmChoiceCallback) { + // ignore (?) + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java new file mode 100644 index 000000000000..2c0ce40c75e8 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import java.io.IOException; +import java.util.Map; + +import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.BaseEncoding; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. (It is not a server in the sense of accepting + * connections on some socket.) + */ +public class SparkSaslServer { + private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); + + /** + * This is passed as the server name when creating the sasl client/server. 
+ * This could be changed to be configurable in the future. + */ + static final String DEFAULT_REALM = "default"; + + /** + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + static final String DIGEST = "DIGEST-MD5"; + + /** + * The quality of protection is just "auth". This means that we are doing + * authentication only, we are not supporting integrity or privacy protection of the + * communication channel after authentication. This could be changed to be configurable + * in the future. + */ + static final Map SASL_PROPS = ImmutableMap.builder() + .put(Sasl.QOP, "auth") + .put(Sasl.SERVER_AUTH, "true") + .build(); + + /** Identifier for a certain secret key within the secretKeyHolder. */ + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslServer saslServer; + + public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS, + new DigestCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Determines whether the authentication exchange has completed successfully. + */ + public synchronized boolean isComplete() { + return saslServer != null && saslServer.isComplete(); + } + + /** + * Used to respond to server SASL tokens. + * @param token Server's SASL token + * @return response to send back to the server. + */ + public synchronized byte[] response(byte[] token) { + try { + return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. + */ + public synchronized void dispose() { + if (saslServer != null) { + try { + saslServer.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslServer = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism. 
+ */ + private class DigestCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL server callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL server callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL server callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback ac = (AuthorizeCallback) callback; + String authId = ac.getAuthenticationID(); + String authzId = ac.getAuthorizationID(); + ac.setAuthorized(authId.equals(authzId)); + if (ac.isAuthorized()) { + ac.setAuthorizedID(authzId); + } + logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized()); + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } + + /* Encode a byte[] identifier as a Base64-encoded string. */ + public static String encodeIdentifier(String identifier) { + Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); + return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8)); + } + + /** Encode a password as a base64-encoded char[] array. */ + public static char[] encodePassword(String password) { + Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); + return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index a9dff31decc8..cd3fea85b19a 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -41,7 +41,7 @@ * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- * level shuffle block. */ -public class ExternalShuffleBlockHandler implements RpcHandler { +public class ExternalShuffleBlockHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); private final ExternalShuffleBlockManager blockManager; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6bbabc44b958..b0b19ba67bdd 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,8 +17,6 @@ package org.apache.spark.network.shuffle; -import java.io.Closeable; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,15 +34,20 @@ * BlockTransferService), which has the downside of losing the shuffle data if we lose the * executors. 
*/ -public class ExternalShuffleClient implements ShuffleClient { +public class ExternalShuffleClient extends ShuffleClient { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); private final TransportClientFactory clientFactory; - private final String appId; - public ExternalShuffleClient(TransportConf conf, String appId) { + private String appId; + + public ExternalShuffleClient(TransportConf conf) { TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); this.clientFactory = context.createClientFactory(); + } + + @Override + public void init(String appId) { this.appId = appId; } @@ -55,6 +58,7 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener) { + assert appId != null : "Called before init()"; logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { TransportClient client = clientFactory.createClient(host, port); @@ -82,6 +86,7 @@ public void registerWithShuffleServer( int port, String execId, ExecutorShuffleInfo executorInfo) { + assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); byte[] registerExecutorMessage = JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index d46a56239455..f72ab40690d0 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -20,7 +20,14 @@ import java.io.Closeable; /** Provides an interface for reading shuffle files, either from an Executor or external service. */ -public interface ShuffleClient extends Closeable { +public abstract class ShuffleClient implements Closeable { + + /** + * Initializes the ShuffleClient, specifying this Executor's appId. + * Must be called before any other method on the ShuffleClient. + */ + public void init(String appId) { } + /** * Fetch a sequence of blocks from a remote node asynchronously, * @@ -28,7 +35,7 @@ public interface ShuffleClient extends Closeable { * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. */ - public void fetchBlocks( + public abstract void fetchBlocks( String host, int port, String execId, diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java new file mode 100644 index 000000000000..84781207861e --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class SaslIntegrationSuite { + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + static TransportContext context; + + TransportClientFactory clientFactory; + + /** Provides a secret key holder which always returns the given secret key. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + + private final String secretKey; + + TestSecretKeyHolder(String secretKey) { + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + @Override + public String getSecretKey(String appId) { + return secretKey; + } + } + + + @BeforeClass + public static void beforeAll() throws IOException { + SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); + SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder); + conf = new TransportConf(new SystemPropertyConfigProvider()); + context = new TransportContext(conf, handler); + server = context.createServer(); + } + + + @AfterClass + public static void afterAll() { + server.close(); + } + + @After + public void afterEach() { + if (clientFactory != null) { + clientFactory.close(); + clientFactory = null; + } + } + + @Test + public void testGoodClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + String msg = "Hello, World!"; + byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + } + + @Test + public void testBadClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key")))); + + try { + // Bootstrap should fail on startup. 
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + @Test + public void testNoSaslClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList()); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.sendRpcSync(new byte[13], 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); + } + + try { + // Guessing the right tag byte doesn't magically get you in... + client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); + } + } + + @Test + public void testNoSaslServer() { + RpcHandler handler = new TestRpcHandler(); + TransportContext context = new TransportContext(conf, handler); + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key")))); + TransportServer server = context.createServer(); + try { + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); + } finally { + server.close(); + } + } + + /** RPC handler which simply responds with the message it received. */ + public static class TestRpcHandler extends RpcHandler { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + callback.onSuccess(message); + } + + @Override + public StreamManager getStreamManager() { + return new OneForOneStreamManager(); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java new file mode 100644 index 000000000000..67a07f38eb5a --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. 
+ */ +public class SparkSaslSuite { + + /** Provides a secret key holder which returns secret key == appId */ + private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + return appId; + } + }; + + @Test + public void testMatching() { + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + assertTrue(server.isComplete()); + + // Disposal should invalidate + server.dispose(); + assertFalse(server.isComplete()); + client.dispose(); + assertFalse(client.isComplete()); + } + + + @Test + public void testNonMatching() { + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + try { + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + fail("Should not have completed"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("Mismatched response")); + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b3bcf5fd68e7..bc101f53844d 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -135,7 +135,8 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + ExternalShuffleClient client = new ExternalShuffleClient(conf); + client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @Override @@ -164,6 +165,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); } + client.close(); return res; } @@ -265,7 +267,8 @@ public void testFetchNoServer() throws Exception { } private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + ExternalShuffleClient client = new ExternalShuffleClient(conf); + client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index ad1a6f01b3a5..0f27f55fec4f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ 
b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -74,6 +74,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr)) + blockManager.initialize("app-id") tempDirectory = Files.createTempDir() manualClock.setTime(0) From e7f735637ad2f681b454d1297f6fdcc433feebbc Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Wed, 5 Nov 2014 14:38:43 -0800 Subject: [PATCH 021/652] [SPARK-4242] [Core] Add SASL to external shuffle service Does three things: (1) Adds SASL to ExternalShuffleClient, (2) puts SecurityManager in BlockManager's constructor, and (3) adds unit test. Author: Aaron Davidson Closes #3108 from aarondav/sasl-client and squashes the following commits: 48b622d [Aaron Davidson] Screw it, let's just get LimitedInputStream 3543b70 [Aaron Davidson] Back out of pom change due to unknown test issue? b58518a [Aaron Davidson] ByteStreams.limit() not available :( cbe451a [Aaron Davidson] Address comments 2bf2908 [Aaron Davidson] [SPARK-4242] [Core] Add SASL to external shuffle service --- LICENSE | 21 +++- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 12 +- .../BlockManagerReplicationSuite.scala | 4 +- .../spark/storage/BlockManagerSuite.scala | 4 +- network/common/pom.xml | 1 + .../buffer/FileSegmentManagedBuffer.java | 3 +- .../network/util/LimitedInputStream.java | 87 ++++++++++++++ network/shuffle/pom.xml | 1 + .../spark/network/sasl/SparkSaslClient.java | 1 - .../spark/network/sasl/SparkSaslServer.java | 9 +- .../shuffle/ExternalShuffleClient.java | 31 ++++- .../ExternalShuffleIntegrationSuite.java | 4 +- .../shuffle/ExternalShuffleSecuritySuite.java | 113 ++++++++++++++++++ .../streaming/ReceivedBlockHandlerSuite.scala | 2 +- 15 files changed, 272 insertions(+), 23 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java diff --git a/LICENSE b/LICENSE index f1732fb47afc..3c667bf45059 100644 --- a/LICENSE +++ b/LICENSE @@ -754,7 +754,7 @@ SUCH DAMAGE. ======================================================================== -For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java): +For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java): ======================================================================== Copyright (C) 2008 The Android Open Source Project @@ -771,6 +771,25 @@ See the License for the specific language governing permissions and limitations under the License. +======================================================================== +For LimitedInputStream + (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java): +======================================================================== +Copyright (C) 2007 The Guava Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + ======================================================================== BSD-style licenses ======================================================================== diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 45e9d7f243e9..e7454beddbfd 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -287,7 +287,7 @@ object SparkEnv extends Logging { // NB: blockManager is not valid until initialize() is called later. val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) + serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 655d16c65c8b..a5fb87b9b2c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -72,7 +72,8 @@ private[spark] class BlockManager( val conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) + blockTransferService: BlockTransferService, + securityManager: SecurityManager) extends BlockDataManager with Logging { val diskBlockManager = new DiskBlockManager(this, conf) @@ -115,7 +116,8 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTranserService to directly connect to other Executors. 
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf)) + new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), securityManager, + securityManager.isAuthenticationEnabled()) } else { blockTransferService } @@ -166,9 +168,10 @@ private[spark] class BlockManager( conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) = { + blockTransferService: BlockTransferService, + securityManager: SecurityManager) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, mapOutputTracker, shuffleManager, blockTransferService) + conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) } /** @@ -219,7 +222,6 @@ private[spark] class BlockManager( return } catch { case e: Exception if i < MAX_ATTEMPTS => - val attemptsRemaining = logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) Thread.sleep(SLEEP_TIME_SECS * 1000) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 1461fa69db90..f63e772bf1e5 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -62,7 +62,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer) + mapOutputTracker, shuffleManager, transfer, securityMgr) store.initialize("app-id") allStores += store store @@ -263,7 +263,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, - 10000, conf, mapOutputTracker, shuffleManager, failableTransfer) + 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 0782876c8e3c..9529502bc8e1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -74,7 +74,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer) + mapOutputTracker, shuffleManager, transfer, securityMgr) manager.initialize("app-id") manager } @@ -795,7 +795,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter // Use Java serializer so we can create an 
unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, - new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer) + new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr) // The put should fail since a1 is not serializable. class UnserializableClass diff --git a/network/common/pom.xml b/network/common/pom.xml index ea887148d98b..6144548a8f99 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -50,6 +50,7 @@ com.google.guava guava + 11.0.2 provided diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 89ed79bc6390..5fa1527ddff9 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -30,6 +30,7 @@ import io.netty.channel.DefaultFileRegion; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.LimitedInputStream; /** * A {@link ManagedBuffer} backed by a segment in a file. @@ -101,7 +102,7 @@ public InputStream createInputStream() throws IOException { try { is = new FileInputStream(file); ByteStreams.skipFully(is, offset); - return ByteStreams.limit(is, length); + return new LimitedInputStream(is, length); } catch (IOException e) { try { if (is != null) { diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java new file mode 100644 index 000000000000..63ca43c04652 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; + +import com.google.common.base.Preconditions; + +/** + * Wraps a {@link InputStream}, limiting the number of bytes which can be read. + * + * This code is from Guava's 14.0 source code, because there is no compatible way to + * use this functionality in both a Guava 11 environment and a Guava >14 environment. 
+ */ +public final class LimitedInputStream extends FilterInputStream { + private long left; + private long mark = -1; + + public LimitedInputStream(InputStream in, long limit) { + super(in); + Preconditions.checkNotNull(in); + Preconditions.checkArgument(limit >= 0, "limit must be non-negative"); + left = limit; + } + @Override public int available() throws IOException { + return (int) Math.min(in.available(), left); + } + // it's okay to mark even if mark isn't supported, as reset won't work + @Override public synchronized void mark(int readLimit) { + in.mark(readLimit); + mark = left; + } + @Override public int read() throws IOException { + if (left == 0) { + return -1; + } + int result = in.read(); + if (result != -1) { + --left; + } + return result; + } + @Override public int read(byte[] b, int off, int len) throws IOException { + if (left == 0) { + return -1; + } + len = (int) Math.min(len, left); + int result = in.read(b, off, len); + if (result != -1) { + left -= result; + } + return result; + } + @Override public synchronized void reset() throws IOException { + if (!in.markSupported()) { + throw new IOException("Mark not supported"); + } + if (mark == -1) { + throw new IOException("Mark not set"); + } + in.reset(); + left = mark; + } + @Override public long skip(long n) throws IOException { + n = Math.min(n, left); + long skipped = in.skip(n); + left -= skipped; + return skipped; + } +} diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index d271704d98a7..fe5681d46349 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -51,6 +51,7 @@ com.google.guava guava + 11.0.2 provided diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index 72ba737b998b..9abad1f30a25 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -126,7 +126,6 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback logger.trace("SASL client callback: setting realm"); RealmCallback rc = (RealmCallback) callback; rc.setText(rc.getDefaultText()); - logger.info("Realm callback"); } else if (callback instanceof RealmChoiceCallback) { // ignore (?) } else { diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index 2c0ce40c75e8..e87b17ead1e1 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -34,7 +34,8 @@ import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; -import com.google.common.io.BaseEncoding; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -159,12 +160,14 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback /* Encode a byte[] identifier as a Base64-encoded string. 
*/ public static String encodeIdentifier(String identifier) { Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); - return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8)); + return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8); } /** Encode a password as a base64-encoded char[] array. */ public static char[] encodePassword(String password) { Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); - return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray(); + return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8).toCharArray(); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index b0b19ba67bdd..3aa95d00f6b2 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,12 +17,18 @@ package org.apache.spark.network.shuffle; +import java.util.List; + +import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; import org.apache.spark.network.util.JavaUtils; @@ -37,18 +43,35 @@ public class ExternalShuffleClient extends ShuffleClient { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); - private final TransportClientFactory clientFactory; + private final TransportConf conf; + private final boolean saslEnabled; + private final SecretKeyHolder secretKeyHolder; + private TransportClientFactory clientFactory; private String appId; - public ExternalShuffleClient(TransportConf conf) { - TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); - this.clientFactory = context.createClientFactory(); + /** + * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled, + * then secretKeyHolder may be null. 
+ */ + public ExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean saslEnabled) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + this.saslEnabled = saslEnabled; } @Override public void init(String appId) { this.appId = appId; + TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + List bootstraps = Lists.newArrayList(); + if (saslEnabled) { + bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder)); + } + clientFactory = context.createClientFactory(bootstraps); } @Override diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index bc101f53844d..71e017b9e4e7 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -135,7 +135,7 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -267,7 +267,7 @@ public void testFetchNoServer() throws Exception { } private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { - ExternalShuffleClient client = new ExternalShuffleClient(conf); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java new file mode 100644 index 000000000000..4c18fcdfbcd8 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleSecuritySuite { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportServer server; + + @Before + public void beforeEach() { + RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(), + new TestSecretKeyHolder("my-app-id", "secret")); + TransportContext context = new TransportContext(conf, handler); + this.server = context.createServer(); + } + + @After + public void afterEach() { + if (server != null) { + server.close(); + server = null; + } + } + + @Test + public void testValid() { + validate("my-app-id", "secret"); + } + + @Test + public void testBadAppId() { + try { + validate("wrong-app-id", "secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); + } + } + + @Test + public void testBadSecret() { + try { + validate("my-app-id", "bad-secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + /** Creates an ExternalShuffleClient and attempts to register with the server. */ + private void validate(String appId, String secretKey) { + ExternalShuffleClient client = + new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); + client.init(appId); + // Registration either succeeds or throws an exception. + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", + new ExecutorShuffleInfo(new String[0], 0, "")); + client.close(); + } + + /** Provides a secret key holder which always returns the given secret key, for a single appId. 
*/ + static class TestSecretKeyHolder implements SecretKeyHolder { + private final String appId; + private final String secretKey; + + TestSecretKeyHolder(String appId, String secretKey) { + this.appId = appId; + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + if (!appId.equals(this.appId)) { + throw new IllegalArgumentException("Wrong appId!"); + } + return secretKey; + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 0f27f55fec4f..9efe15d01ed0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -73,7 +73,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, - new NioBlockTransferService(conf, securityMgr)) + new NioBlockTransferService(conf, securityMgr), securityMgr) blockManager.initialize("app-id") tempDirectory = Files.createTempDir() From 866c7bbe56f9c7fd96d3f4afe8a76405dc877a6e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 3 Nov 2014 18:18:47 -0800 Subject: [PATCH 022/652] [SPARK-611] Display executor thread dumps in web UI This patch allows executor thread dumps to be collected on-demand and viewed in the Spark web UI. The thread dumps are collected using Thread.getAllStackTraces(). To allow remote thread dumps to be triggered from the web UI, I added a new `ExecutorActor` that runs inside of the Executor actor system and responds to RPCs from the driver. The driver's mechanism for obtaining a reference to this actor is a little bit hacky: it uses the block manager master actor to determine the host/port of the executor actor systems in order to construct ActorRefs to ExecutorActor. Unfortunately, I couldn't find a much cleaner way to do this without a big refactoring of the executor -> driver communication. Screenshots: ![image](https://cloud.githubusercontent.com/assets/50748/4781793/7e7a0776-5cbf-11e4-874d-a91cd04620bd.png) ![image](https://cloud.githubusercontent.com/assets/50748/4781794/8bce76aa-5cbf-11e4-8d13-8477748c9f7e.png) ![image](https://cloud.githubusercontent.com/assets/50748/4781797/bd11a8b8-5cbf-11e4-9ad7-a7459467ec8e.png) Author: Josh Rosen Closes #2944 from JoshRosen/jstack-in-web-ui and squashes the following commits: 3c21a5d [Josh Rosen] Address review comments: 880f7f7 [Josh Rosen] Merge remote-tracking branch 'origin/master' into jstack-in-web-ui f719266 [Josh Rosen] Merge remote-tracking branch 'origin/master' into jstack-in-web-ui 19707b0 [Josh Rosen] Add one comment. 127a130 [Josh Rosen] Update to use SparkContext.DRIVER_IDENTIFIER b8e69aa [Josh Rosen] Merge remote-tracking branch 'origin/master' into jstack-in-web-ui 3dfc2d4 [Josh Rosen] Add missing file. bc1e675 [Josh Rosen] Undo some leftover changes from the earlier approach. f4ac1c1 [Josh Rosen] Switch to on-demand collection of thread dumps dfec08b [Josh Rosen] Add option to disable thread dumps in UI. 4c87d7f [Josh Rosen] Use separate RPC for sending thread dumps. 2b8bdf3 [Josh Rosen] Enable thread dumps from the driver when running in non-local mode. cc3e6b3 [Josh Rosen] Fix test code in DAGSchedulerSuite. 
87b8b65 [Josh Rosen] Add new listener event for thread dumps. 8c10216 [Josh Rosen] Add missing file. 0f198ac [Josh Rosen] [SPARK-611] Display executor thread dumps in web UI --- .../scala/org/apache/spark/SparkContext.scala | 29 +++++++- .../CoarseGrainedExecutorBackend.scala | 3 +- .../org/apache/spark/executor/Executor.scala | 7 +- .../apache/spark/executor/ExecutorActor.scala | 41 +++++++++++ .../spark/storage/BlockManagerMaster.scala | 4 + .../storage/BlockManagerMasterActor.scala | 18 +++++ .../spark/storage/BlockManagerMessages.scala | 2 + .../ui/exec/ExecutorThreadDumpPage.scala | 73 +++++++++++++++++++ .../apache/spark/ui/exec/ExecutorsPage.scala | 15 +++- .../apache/spark/ui/exec/ExecutorsTab.scala | 8 +- .../org/apache/spark/util/AkkaUtils.scala | 14 ++++ .../apache/spark/util/ThreadStackTrace.scala | 27 +++++++ .../scala/org/apache/spark/util/Utils.scala | 13 ++++ 13 files changed, 247 insertions(+), 7 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala create mode 100644 core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d65027d18e2d..3cdaa6a9cc8a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,9 +21,8 @@ import scala.language.implicitConversions import java.io._ import java.net.URI -import java.util.Arrays +import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger -import java.util.{Properties, UUID} import java.util.UUID.randomUUID import scala.collection.{Map, Set} import scala.collection.generic.Growable @@ -41,6 +40,7 @@ import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} +import org.apache.spark.executor.TriggerThreadDump import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ @@ -51,7 +51,7 @@ import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ import org.apache.spark.ui.SparkUI import org.apache.spark.ui.jobs.JobProgressListener -import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} +import org.apache.spark.util._ /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -363,6 +363,29 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { override protected def childValue(parent: Properties): Properties = new Properties(parent) } + /** + * Called by the web UI to obtain executor thread dumps. This method may be expensive. + * Logs an error and returns None if we failed to obtain a thread dump, which could occur due + * to an executor being dead or unresponsive or due to network issues while sending the thread + * dump message back to the driver. 
+ */ + private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = { + try { + if (executorId == SparkContext.DRIVER_IDENTIFIER) { + Some(Utils.getThreadDump()) + } else { + val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get + val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) + Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, + AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + } + } catch { + case e: Exception => + logError(s"Exception getting thread dump from executor $executorId", e) + None + } + } + private[spark] def getLocalProperties: Properties = localProperties.get() private[spark] def setLocalProperties(props: Properties) { diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 697154d762d4..3711824a40cf 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -131,7 +131,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Create a new ActorSystem using driver's Spark properties to run the backend. val driverConf = new SparkConf().setAll(props) val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf)) + SparkEnv.executorActorSystemName, + hostname, port, driverConf, new SecurityManager(driverConf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 7dd5265891c3..abc1dd0be623 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.ActorSystem +import akka.actor.{Props, ActorSystem} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -93,6 +93,10 @@ private[spark] class Executor( } } + // Create an actor for receiving RPCs from the driver + private val executorActor = env.actorSystem.actorOf( + Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -132,6 +136,7 @@ private[spark] class Executor( def stop() { env.metricsSystem.report() + env.actorSystem.stop(executorActor) isStopped = true threadPool.shutdown() if (!isLocal) { diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala new file mode 100644 index 000000000000..41925f7e97e8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import akka.actor.Actor +import org.apache.spark.Logging + +import org.apache.spark.util.{Utils, ActorLogReceive} + +/** + * Driver -> Executor message to trigger a thread dump. + */ +private[spark] case object TriggerThreadDump + +/** + * Actor that runs inside of executors to enable driver -> executor RPC. + */ +private[spark] +class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { + case TriggerThreadDump => + sender ! Utils.getThreadDump() + } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index d08e1419e3e4..b63c7f191155 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -88,6 +88,10 @@ class BlockManagerMaster( askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } + def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + } + /** * Remove a block from the slaves that have it. This can only be used to remove * blocks that the driver knows about. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 5e375a255397..685b2e11440f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetPeers(blockManagerId) => sender ! getPeers(blockManagerId) + case GetActorSystemHostPortForExecutor(executorId) => + sender ! getActorSystemHostPortForExecutor(executorId) + case GetMemoryStatus => sender ! memoryStatus @@ -412,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus Seq.empty } } + + /** + * Returns the hostname and port of an executor's actor system, based on the Akka address of its + * BlockManagerSlaveActor. 
+ */ + private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + for ( + blockManagerId <- blockManagerIdByExecutor.get(executorId); + info <- blockManagerInfo.get(blockManagerId); + host <- info.slaveActor.path.address.host; + port <- info.slaveActor.path.address.port + ) yield { + (host, port) + } + } } @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 291ddfcc113a..3f32099d08cc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -92,6 +92,8 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class RemoveExecutor(execId: String) extends ToBlockManagerMaster case object StopBlockManagerMaster extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala new file mode 100644 index 000000000000..e9c755e36f71 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.exec + +import javax.servlet.http.HttpServletRequest + +import scala.util.Try +import scala.xml.{Text, Node} + +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") { + + private val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val executorId = Option(request.getParameter("executorId")).getOrElse { + return Text(s"Missing executorId parameter") + } + val time = System.currentTimeMillis() + val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) + + val content = maybeThreadDump.map { threadDump => + val dumpRows = threadDump.map { thread => + + } + +
+

+        Updated at {UIUtils.formatDate(time)}
+        {
+          // scalastyle:off
+          Expand All
+          // scalastyle:on
+        }
+        {dumpRows}
+ }.getOrElse(Text("Error fetching thread dump")) + UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b0e3bb3b552f..048fee3ce1ff 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -41,7 +41,10 @@ private case class ExecutorSummaryInfo( totalShuffleWrite: Long, maxMemory: Long) -private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { +private[ui] class ExecutorsPage( + parent: ExecutorsTab, + threadDumpEnabled: Boolean) + extends WebUIPage("") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -75,6 +78,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { Shuffle Write + {if (threadDumpEnabled) Thread Dump else Seq.empty} {execInfoSorted.map(execRow)} @@ -133,6 +137,15 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { {Utils.bytesToString(info.totalShuffleWrite)} + { + if (threadDumpEnabled) { + + Thread Dump + + } else { + Seq.empty + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 9e0e71a51a40..ba97630f025c 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -27,8 +27,14 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { val listener = parent.executorsListener + val sc = parent.sc + val threadDumpEnabled = + sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true) - attachPage(new ExecutorsPage(this)) + attachPage(new ExecutorsPage(this, threadDumpEnabled)) + if (threadDumpEnabled) { + attachPage(new ExecutorThreadDumpPage(this)) + } } /** diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 79e398eb8c10..10010bdfa1a5 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -212,4 +212,18 @@ private[spark] object AkkaUtils extends Logging { logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } + + def makeExecutorRef( + name: String, + conf: SparkConf, + host: String, + port: Int, + actorSystem: ActorSystem): ActorRef = { + val executorActorSystemName = SparkEnv.executorActorSystemName + Utils.checkHost(host, "Expected hostname") + val url = s"akka.tcp://$executorActorSystemName@$host:$port/user/$name" + val timeout = AkkaUtils.lookupTimeout(conf) + logInfo(s"Connecting to $name: $url") + Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + } } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala new file mode 100644 index 000000000000..d4e0ad93b966 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Used for shipping per-thread stacktraces from the executors to driver. + */ +private[spark] case class ThreadStackTrace( + threadId: Long, + threadName: String, + threadState: Thread.State, + stackTrace: String) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a33046d2040d..6ab94af9f373 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io._ +import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer import java.util.jar.Attributes.Name @@ -1611,6 +1612,18 @@ private[spark] object Utils extends Logging { s"$className: $desc\n$st" } + /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ + def getThreadDump(): Array[ThreadStackTrace] = { + // We need to filter out null values here because dumpAllThreads() may return null array + // elements for threads that are dead / don't exist. + val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) + threadInfos.sortBy(_.getThreadId).map { case threadInfo => + val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n") + ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, + threadInfo.getThreadState, stackTrace) + } + } + /** * Convert all spark properties set in the given SparkConf to a sequence of java options. */ From 7517c37aee373c8bd3ccbf1eae079b0fc6b89c91 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Mon, 3 Nov 2014 18:17:32 -0800 Subject: [PATCH 023/652] [SPARK-4168][WebUI] web statges number should show correctly when stages are more than 1000 The number of completed stages and failed stages showed on webUI will always be less than 1000. This is really misleading when there are already thousands of stages completed or failed. The number should be correct even when only partial stages listed on the webUI (stage info will be removed if the number is too large). 
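To make the bookkeeping concrete, here is a minimal sketch of the pattern this patch adopts: a trimmed buffer keeps only the most recent stage infos for display, while a separate counter tracks the true total. The class and field names below are illustrative assumptions; the real change lives in `JobProgressListener`, as the diff that follows shows.

~~~
import scala.collection.mutable.ListBuffer

// Illustrative stand-in for the listener's bookkeeping. The buffer is trimmed
// for memory, so its size under-reports once more than `retainedStages` stages
// have finished; the plain counter never does.
class StageCounter(retainedStages: Int = 1000) {
  val completedStages = ListBuffer[String]() // only the most recent stages
  var numCompletedStages = 0                 // true total, never trimmed

  def onStageCompleted(stage: String): Unit = {
    completedStages += stage
    numCompletedStages += 1
    if (completedStages.size > retainedStages) {
      // drop the oldest entries to keep the buffer bounded
      completedStages.trimStart(completedStages.size - retainedStages)
    }
  }
}
~~~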
Author: Zhang, Liye Closes #3035 from liyezhang556520/webStageNum and squashes the following commits: d9e29fb [Zhang, Liye] add detailed comments for variables 4ea8fd1 [Zhang, Liye] change variable name accroding to comments f4c404d [Zhang, Liye] [SPARK-4168][WebUI] web statges number should show correctly when stages are more than 1000 --- .../org/apache/spark/ui/jobs/JobProgressListener.scala | 9 +++++++++ .../org/apache/spark/ui/jobs/JobProgressPage.scala | 10 ++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index b5207360510d..e3223403c17f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -59,6 +59,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val failedStages = ListBuffer[StageInfo]() val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] val stageIdToInfo = new HashMap[StageId, StageInfo] + + // Number of completed and failed stages, may not actually equal to completedStages.size and + // failedStages.size respectively due to completedStage and failedStages only maintain the latest + // part of the stages, the earlier ones will be removed when there are too many stages for + // memory sake. + var numCompletedStages = 0 + var numFailedStages = 0 // Map from pool name to a hash map (map from stage id to StageInfo). val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() @@ -110,9 +117,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { activeStages.remove(stage.stageId) if (stage.failureReason.isEmpty) { completedStages += stage + numCompletedStages += 1 trimIfNecessary(completedStages) } else { failedStages += stage + numFailedStages += 1 trimIfNecessary(failedStages) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index 6e718eecdd52..83a7898071c9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -34,7 +34,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") listener.synchronized { val activeStages = listener.activeStages.values.toSeq val completedStages = listener.completedStages.reverse.toSeq + val numCompletedStages = listener.numCompletedStages val failedStages = listener.failedStages.reverse.toSeq + val numFailedStages = listener.numFailedStages val now = System.currentTimeMillis val activeStagesTable = @@ -69,11 +71,11 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
  • Completed Stages: - {completedStages.size} + {numCompletedStages}
  • Failed Stages: - {failedStages.size} + {numFailedStages}
  • @@ -86,9 +88,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") }} ++

         Active Stages ({activeStages.size}) ++
         activeStagesTable.toNodeSeq ++
-        Completed Stages ({completedStages.size}) ++
+        Completed Stages ({numCompletedStages}) ++
         completedStagesTable.toNodeSeq ++
-        Failed Stages ({failedStages.size}) ++
+        Failed Stages ({numFailedStages})
    ++ failedStagesTable.toNodeSeq UIUtils.headerSparkPage("Spark Stages", content, parent) From e0a043b79c250515a680485f0dc7b1a149835445 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 3 Nov 2014 22:40:43 -0800 Subject: [PATCH 024/652] [SPARK-4163][Core] Add a backward compatibility test for FetchFailed /cc aarondav Author: zsxwing Closes #3086 from zsxwing/SPARK-4163-back-comp and squashes the following commits: 21cb2a8 [zsxwing] Add a backward compatibility test for FetchFailed --- .../org/apache/spark/util/JsonProtocolSuite.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a91c9ddeaef3..01030120ae54 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -177,6 +177,17 @@ class JsonProtocolSuite extends FunSuite { deserializedBmRemoved) } + test("FetchFailed backwards compatibility") { + // FetchFailed in Spark 1.1.0 does not have an "Message" property. + val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "ignored") + val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed) + .removeField({ _._1 == "Message" }) + val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "Unknown reason") + assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + test("SparkListenerApplicationStart backwards compatibility") { // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user") From 68be37b823516dbeda066776bb060bf894db4e95 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 3 Nov 2014 22:47:45 -0800 Subject: [PATCH 025/652] [SPARK-4166][Core] Add a backward compatibility test for ExecutorLostFailure Author: zsxwing Closes #3085 from zsxwing/SPARK-4166-back-comp and squashes the following commits: 89329f4 [zsxwing] Add a backward compatibility test for ExecutorLostFailure --- .../scala/org/apache/spark/util/JsonProtocolSuite.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 01030120ae54..aec1e409db95 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -196,6 +196,15 @@ class JsonProtocolSuite extends FunSuite { assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent)) } + test("ExecutorLostFailure backward compatibility") { + // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property. + val executorLostFailure = ExecutorLostFailure("100") + val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure) + .removeField({ _._1 == "Executor ID" }) + val expectedExecutorLostFailure = ExecutorLostFailure("Unknown") + assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ From b27d7dcaaad0bf04d341660ffbeb742cd4eecfd3 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Mon, 3 Nov 2014 09:02:35 -0800 Subject: [PATCH 026/652] [EC2] Factor out Mesos spark-ec2 branch We reference a specific branch in two places. 
This patch makes it one place. Author: Nicholas Chammas Closes #3008 from nchammas/mesos-spark-ec2-branch and squashes the following commits: 10a6089 [Nicholas Chammas] factor out mess spark-ec2 branch --- ec2/spark_ec2.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0d6b82b4944f..50f88f735650 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -41,8 +41,9 @@ DEFAULT_SPARK_VERSION = "1.1.0" +MESOS_SPARK_EC2_BRANCH = "v4" # A URL prefix from which to fetch AMI information -AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" +AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH) class UsageError(Exception): @@ -583,7 +584,13 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4") + ssh( + host=master, + opts=opts, + command="rm -rf spark-ec2" + + " && " + + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH) + ) print "Deploying files to master..." deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) From f4beb77f083e477845b90b5049186095d2002f49 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 5 Nov 2014 15:30:31 -0800 Subject: [PATCH 027/652] [SPARK-3984] [SPARK-3983] Fix incorrect scheduler delay and display task deserialization time in UI This commit fixes the scheduler delay in the UI (which previously included things that are not scheduler delay, like time to deserialize the task and serialize the result), and also adds information about time to deserialize tasks to the optional additional metrics. Time to deserialize the task can be large relative to task time for short jobs, and understanding when it is high can help developers realize that they should try to reduce closure size (e.g, by including less data in the task description). cc shivaram etrain Author: Kay Ousterhout Closes #2832 from kayousterhout/SPARK-3983 and squashes the following commits: 0c1398e [Kay Ousterhout] Fixed ordering 531575d [Kay Ousterhout] Removed executor launch time 1f13afe [Kay Ousterhout] Minor spacing fixes 335be4b [Kay Ousterhout] Made metrics hideable 5bc3cba [Kay Ousterhout] [SPARK-3984] [SPARK-3983] Improve UI task metrics. 
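The corrected arithmetic is easiest to see with made-up timings; the metric names below follow the `TaskMetrics` fields used in the diff, while the numbers are purely hypothetical.

~~~
// Hypothetical timings for one task, in milliseconds.
val totalExecutionTime      = 1000L // finish time - launch time
val executorRunTime         = 850L  // running the task body on the executor
val executorDeserializeTime = 60L   // deserializing the task closure
val resultSerializationTime = 40L   // serializing the task result

// Previously, everything outside executorRunTime was attributed to the scheduler:
val oldSchedulerDelay = totalExecutionTime - executorRunTime // 150 ms (overstated)

// With this patch, executor-side overheads are subtracted out first:
val executorOverhead  = executorDeserializeTime + resultSerializationTime
val newSchedulerDelay = totalExecutionTime - executorRunTime - executorOverhead // 50 ms
~~~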
(cherry picked from commit a46497eecc50f854c5c5701dc2b8a2468b76c085) Signed-off-by: Kay Ousterhout --- .../org/apache/spark/executor/Executor.scala | 4 +-- .../scala/org/apache/spark/ui/ToolTips.scala | 3 ++ .../org/apache/spark/ui/jobs/StagePage.scala | 31 ++++++++++++++++++- .../spark/ui/jobs/TaskDetailsClassNames.scala | 1 + 4 files changed, 36 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index abc1dd0be623..96114571d6c7 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -161,7 +161,7 @@ private[spark] class Executor( } override def run() { - val startTime = System.currentTimeMillis() + val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") @@ -206,7 +206,7 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.executorDeserializeTime = taskStart - startTime + m.executorDeserializeTime = taskStart - deserializeStartTime m.executorRunTime = taskFinish - taskStart m.jvmGCTime = gcTime - startGCTime m.resultSerializationTime = afterSerialization - beforeSerialization diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index f02904df31fc..51dc08f668a4 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,6 +24,9 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" + val TASK_DESERIALIZATION_TIME = + """Time spent deserializating the task closure on the executor.""" + val INPUT = "Bytes read from Hadoop or from Spark storage." val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7cc03b7d333d..63ed5fc4949c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -112,6 +112,13 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { Scheduler Delay +
  • + + + Task Deserialization Time + +
  • @@ -147,6 +154,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), + ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), ("GC Time", TaskDetailsClassNames.GC_TIME), ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ @@ -179,6 +187,17 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } } + val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.executorDeserializeTime.toDouble + } + val deserializationQuantiles = + + + Task Deserialization Time + + +: getFormattedTimeQuantiles(deserializationTimes) + val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorRunTime.toDouble } @@ -266,6 +285,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val listings: Seq[Seq[Node]] = Seq( {serviceQuantiles}, {schedulerDelayQuantiles}, + + {deserializationQuantiles} + {gcQuantiles}, {serializationQuantiles} @@ -314,6 +336,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L) val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) val gettingResultTime = info.gettingResultTime @@ -367,6 +390,10 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { class={TaskDetailsClassNames.SCHEDULER_DELAY}> {UIUtils.formatDuration(schedulerDelay.toLong)} + + {UIUtils.formatDuration(taskDeserializationTime.toLong)} + {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} @@ -424,6 +451,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { (info.finishTime - info.launchTime) } } - totalExecutionTime - metrics.executorRunTime + val executorOverhead = (metrics.executorDeserializeTime + + metrics.resultSerializationTime) + totalExecutionTime - metrics.executorRunTime - executorOverhead } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 23d672cabda0..eb371bd0ea7e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -24,6 +24,7 @@ package org.apache.spark.ui.jobs private object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val GC_TIME = "gc_time" + val TASK_DESERIALIZATION_TIME = "deserialization_time" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } From 6844e7a8219ac78790a422ffd5054924e7d2bea1 Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Wed, 5 Nov 2014 15:38:48 -0800 Subject: [PATCH 028/652] SPARK-4222 [CORE] use readFully in FixedLengthBinaryRecordReader replaces the existing read() call with readFully(). 
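The distinction matters because a single `read` call is allowed to return fewer bytes than requested, whereas `readFully` keeps reading until the buffer is completely filled (or throws `EOFException`). Below is a small sketch of the two contracts using a plain `java.io.DataInputStream`; the record reader itself reads from Hadoop's `FSDataInputStream`, which offers the same `readFully` behaviour.

~~~
import java.io.{ByteArrayInputStream, DataInputStream}

val recordLength = 8
val record = Array.fill[Byte](recordLength)(42.toByte)
val buffer = new Array[Byte](recordLength)

// read() only promises *some* bytes: it may legitimately stop short of
// recordLength, silently leaving the tail of the buffer unfilled.
val partialRead = new DataInputStream(new ByteArrayInputStream(record))
val bytesRead = partialRead.read(buffer, 0, recordLength) // may be < recordLength

// readFully() loops internally until the whole buffer is filled,
// throwing EOFException if the stream ends first.
val fullRead = new DataInputStream(new ByteArrayInputStream(record))
fullRead.readFully(buffer)
~~~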
Author: industrial-sloth Closes #3093 from industrial-sloth/branch-1.2-fixedLenRecRdr and squashes the following commits: a245c8a [industrial-sloth] use readFully in FixedLengthBinaryRecordReader --- .../org/apache/spark/input/FixedLengthBinaryRecordReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala index 5164a74bec4e..36a1e5d475f4 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -115,7 +115,7 @@ private[spark] class FixedLengthBinaryRecordReader if (currentPosition < splitEnd) { // setup a buffer to store the record val buffer = recordValue.getBytes - fileInputStream.read(buffer, 0, recordLength) + fileInputStream.readFully(buffer) // update our current position currentPosition = currentPosition + recordLength // return true From cf2f676f93807bc504b77409b6c3d66f0d5e38ab Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 5 Nov 2014 15:42:05 -0800 Subject: [PATCH 029/652] [SPARK-3797] Run external shuffle service in Yarn NM This creates a new module `network/yarn` that depends on `network/shuffle` recently created in #3001. This PR introduces a custom Yarn auxiliary service that runs the external shuffle service. As of the changes here this shuffle service is required for using dynamic allocation with Spark. This is still WIP mainly because it doesn't handle security yet. I have tested this on a stable Yarn cluster. Author: Andrew Or Closes #3082 from andrewor14/yarn-shuffle-service and squashes the following commits: ef3ddae [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 0ee67a2 [Andrew Or] Minor wording suggestions 1c66046 [Andrew Or] Remove unused provided dependencies 0eb6233 [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 6489db5 [Andrew Or] Try catch at the right places 7b71d8f [Andrew Or] Add detailed java docs + reword a few comments d1124e4 [Andrew Or] Add security to shuffle service (INCOMPLETE) 5f8a96f [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 9b6e058 [Andrew Or] Address various feedback f48b20c [Andrew Or] Fix tests again f39daa6 [Andrew Or] Do not make network-yarn an assembly module 761f58a [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 15a5b37 [Andrew Or] Fix build for Hadoop 1.x baff916 [Andrew Or] Fix tests 5bf9b7e [Andrew Or] Address a few minor comments 5b419b8 [Andrew Or] Add missing license header 804e7ff [Andrew Or] Include the Yarn shuffle service jar in the distribution cd076a4 [Andrew Or] Require external shuffle service for dynamic allocation ea764e0 [Andrew Or] Connect to Yarn shuffle service only if it's enabled 1bf5109 [Andrew Or] Use the shuffle service port specified through hadoop config b4b1f0c [Andrew Or] 4 tabs -> 2 tabs 43dcb96 [Andrew Or] First cut integration of shuffle service with Yarn aux service b54a0c4 [Andrew Or] Initial skeleton for Yarn shuffle service (cherry picked from commit 61a5cced049a8056292ba94f23fa7bd040f50685) Signed-off-by: Andrew Or --- .../spark/ExecutorAllocationManager.scala | 37 +++- .../apache/spark/storage/BlockManager.scala | 8 +- .../scala/org/apache/spark/util/Utils.scala | 16 ++ make-distribution.sh | 3 + .../network/sasl/ShuffleSecretManager.java | 
117 ++++++++++++ network/yarn/pom.xml | 58 ++++++ .../network/yarn/YarnShuffleService.java | 176 ++++++++++++++++++ .../yarn/util/HadoopConfigProvider.java | 42 +++++ pom.xml | 2 + project/SparkBuild.scala | 8 +- .../spark/deploy/yarn/ExecutorRunnable.scala | 16 ++ .../spark/deploy/yarn/ExecutorRunnable.scala | 16 ++ 12 files changed, 483 insertions(+), 16 deletions(-) create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java create mode 100644 network/yarn/pom.xml create mode 100644 network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java create mode 100644 network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index c11f1db0064f..ef93009a074e 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -66,7 +66,6 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Lower and upper bounds on the number of executors. These are required. private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1) private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1) - verifyBounds() // How long there must be backlogged tasks for before an addition is triggered private val schedulerBacklogTimeout = conf.getLong( @@ -77,9 +76,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) // How long an executor must be idle for before it is removed - private val removeThresholdSeconds = conf.getLong( + private val executorIdleTimeout = conf.getLong( "spark.dynamicAllocation.executorIdleTimeout", 600) + // During testing, the methods to actually kill and add executors are mocked out + private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) + + validateSettings() + // Number of executors to add in the next round private var numExecutorsToAdd = 1 @@ -103,17 +107,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Polling loop interval (ms) private val intervalMillis: Long = 100 - // Whether we are testing this class. This should only be used internally. - private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) - // Clock used to schedule when executors should be added and removed private var clock: Clock = new RealClock /** - * Verify that the lower and upper bounds on the number of executors are valid. + * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. 
*/ - private def verifyBounds(): Unit = { + private def validateSettings(): Unit = { if (minNumExecutors < 0 || maxNumExecutors < 0) { throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!") } @@ -124,6 +125,22 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!") } + if (schedulerBacklogTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!") + } + if (sustainedSchedulerBacklogTimeout <= 0) { + throw new SparkException( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") + } + if (executorIdleTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") + } + // Require external shuffle service for dynamic allocation + // Otherwise, we may lose shuffle files when killing executors + if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) { + throw new SparkException("Dynamic allocation of executors requires the external " + + "shuffle service. You may enable this through spark.shuffle.service.enabled.") + } } /** @@ -254,7 +271,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging val removeRequestAcknowledged = testing || sc.killExecutor(executorId) if (removeRequestAcknowledged) { logInfo(s"Removing executor $executorId because it has been idle for " + - s"$removeThresholdSeconds seconds (new desired total will be ${numExistingExecutors - 1})") + s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})") executorsPendingToRemove.add(executorId) true } else { @@ -329,8 +346,8 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging private def onExecutorIdle(executorId: String): Unit = synchronized { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $removeThresholdSeconds seconds)") - removeTimes(executorId) = clock.getTimeMillis + removeThresholdSeconds * 1000 + s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") + removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a5fb87b9b2c5..e48d7772d6ee 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -40,7 +40,6 @@ import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager -import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues @@ -97,7 +96,12 @@ private[spark] class BlockManager( private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) - private val externalShuffleServicePort = conf.getInt("spark.shuffle.service.port", 7337) + + // Port used by the external shuffle service. 
In Yarn mode, this may be already be + // set through the Hadoop configuration as the server is launched in the Yarn NM. + private val externalShuffleServicePort = + Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + // Check that we're not using external shuffle service with consolidated shuffle files. if (externalShuffleServiceEnabled && conf.getBoolean("spark.shuffle.consolidateFiles", false) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6ab94af9f373..7caf6bcf94ef 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -45,6 +45,7 @@ import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ @@ -1780,6 +1781,21 @@ private[spark] object Utils extends Logging { val manifest = new JarManifest(manifestUrl.openStream()) manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION) }.getOrElse("Unknown") + + /** + * Return the value of a config either through the SparkConf or the Hadoop configuration + * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf + * if the key is not set in the Hadoop configuration. + */ + def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = { + val sparkValue = conf.get(key, default) + if (SparkHadoopUtil.get.isYarnMode) { + SparkHadoopUtil.get.newConfiguration(conf).get(key, sparkValue) + } else { + sparkValue + } + } + } /** diff --git a/make-distribution.sh b/make-distribution.sh index 0bc839e1dbe4..fac7f7e284be 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -181,6 +181,9 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DI # Copy jars cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" +cp "$FWDIR"/network/yarn/target/scala*/spark-network-yarn*.jar "$DISTDIR/lib/" +cp "$FWDIR"/network/yarn/target/scala*/spark-network-shuffle*.jar "$DISTDIR/lib/" +cp "$FWDIR"/network/yarn/target/scala*/spark-network-common*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java new file mode 100644 index 000000000000..e66c4af0f1eb --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.lang.Override; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.sasl.SecretKeyHolder; + +/** + * A class that manages shuffle secret used by the external shuffle service. + */ +public class ShuffleSecretManager implements SecretKeyHolder { + private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class); + private final ConcurrentHashMap shuffleSecretMap; + + private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + + // Spark user used for authenticating SASL connections + // Note that this must match the value in org.apache.spark.SecurityManager + private static final String SPARK_SASL_USER = "sparkSaslUser"; + + /** + * Convert the given string to a byte buffer. The resulting buffer can be converted back to + * the same string through {@link #bytesToString(ByteBuffer)}. This is used if the external + * shuffle service represents shuffle secrets as bytes buffers instead of strings. + */ + public static ByteBuffer stringToBytes(String s) { + return ByteBuffer.wrap(s.getBytes(UTF8_CHARSET)); + } + + /** + * Convert the given byte buffer to a string. The resulting string can be converted back to + * the same byte buffer through {@link #stringToBytes(String)}. This is used if the external + * shuffle service represents shuffle secrets as bytes buffers instead of strings. + */ + public static String bytesToString(ByteBuffer b) { + return new String(b.array(), UTF8_CHARSET); + } + + public ShuffleSecretManager() { + shuffleSecretMap = new ConcurrentHashMap(); + } + + /** + * Register an application with its secret. + * Executors need to first authenticate themselves with the same secret before + * fetching shuffle files written by other executors in this application. + */ + public void registerApp(String appId, String shuffleSecret) { + if (!shuffleSecretMap.contains(appId)) { + shuffleSecretMap.put(appId, shuffleSecret); + logger.info("Registered shuffle secret for application {}", appId); + } else { + logger.debug("Application {} already registered", appId); + } + } + + /** + * Register an application with its secret specified as a byte buffer. + */ + public void registerApp(String appId, ByteBuffer shuffleSecret) { + registerApp(appId, bytesToString(shuffleSecret)); + } + + /** + * Unregister an application along with its secret. + * This is called when the application terminates. + */ + public void unregisterApp(String appId) { + if (shuffleSecretMap.contains(appId)) { + shuffleSecretMap.remove(appId); + logger.info("Unregistered shuffle secret for application {}", appId); + } else { + logger.warn("Attempted to unregister application {} when it is not registered", appId); + } + } + + /** + * Return the Spark user for authenticating SASL connections. + */ + @Override + public String getSaslUser(String appId) { + return SPARK_SASL_USER; + } + + /** + * Return the secret key registered with the given application. 
+ * This key is used to authenticate the executors before they can fetch shuffle files + * written by this application from the external shuffle service. If the specified + * application is not registered, return null. + */ + @Override + public String getSecretKey(String appId) { + return shuffleSecretMap.get(appId); + } +} diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml new file mode 100644 index 000000000000..e60d8c1f7876 --- /dev/null +++ b/network/yarn/pom.xml @@ -0,0 +1,58 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.2.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-yarn_2.10 + jar + Spark Project Yarn Shuffle Service Code + http://spark.apache.org/ + + network-yarn + + + + + + org.apache.spark + spark-network-shuffle_2.10 + ${project.version} + + + + + org.apache.hadoop + hadoop-client + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java new file mode 100644 index 000000000000..bb0b8f7e6cba --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn; + +import java.lang.Override; +import java.nio.ByteBuffer; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.server.api.AuxiliaryService; +import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; +import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; +import org.apache.hadoop.yarn.server.api.ContainerInitializationContext; +import org.apache.hadoop.yarn.server.api.ContainerTerminationContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.ShuffleSecretManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.yarn.util.HadoopConfigProvider; + +/** + * An external shuffle service used by Spark on Yarn. + * + * This is intended to be a long-running auxiliary service that runs in the NodeManager process. + * A Spark application may connect to this service by setting `spark.shuffle.service.enabled`. 
+ * The application also automatically derives the service port through `spark.shuffle.service.port` + * specified in the Yarn configuration. This is so that both the clients and the server agree on + * the same port to communicate on. + * + * The service also optionally supports authentication. This ensures that executors from one + * application cannot read the shuffle files written by those from another. This feature can be + * enabled by setting `spark.authenticate` in the Yarn configuration before starting the NM. + * Note that the Spark application must also set `spark.authenticate` manually and, unlike in + * the case of the service port, will not inherit this setting from the Yarn configuration. This + * is because an application running on the same Yarn cluster may choose to not use the external + * shuffle service, in which case its setting of `spark.authenticate` should be independent of + * the service's. + */ +public class YarnShuffleService extends AuxiliaryService { + private final Logger logger = LoggerFactory.getLogger(YarnShuffleService.class); + + // Port on which the shuffle server listens for fetch requests + private static final String SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port"; + private static final int DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337; + + // Whether the shuffle server should authenticate fetch requests + private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate"; + private static final boolean DEFAULT_SPARK_AUTHENTICATE = false; + + // An entity that manages the shuffle secret per application + // This is used only if authentication is enabled + private ShuffleSecretManager secretManager; + + // The actual server that serves shuffle files + private TransportServer shuffleServer = null; + + public YarnShuffleService() { + super("spark_shuffle"); + logger.info("Initializing YARN shuffle service for Spark"); + } + + /** + * Return whether authentication is enabled as specified by the configuration. + * If so, fetch requests will fail unless the appropriate authentication secret + * for the application is provided. + */ + private boolean isAuthenticationEnabled() { + return secretManager != null; + } + + /** + * Start the shuffle server with the given configuration. + */ + @Override + protected void serviceInit(Configuration conf) { + // If authentication is enabled, set up the shuffle server to use a + // special RPC handler that filters out unauthenticated fetch requests + boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); + RpcHandler rpcHandler = new ExternalShuffleBlockHandler(); + if (authEnabled) { + secretManager = new ShuffleSecretManager(); + rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); + } + + int port = conf.getInt( + SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); + TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + TransportContext transportContext = new TransportContext(transportConf, rpcHandler); + shuffleServer = transportContext.createServer(port); + String authEnabledString = authEnabled ? "enabled" : "not enabled"; + logger.info("Started YARN shuffle service for Spark on port {}. 
" + + "Authentication is {}.", port, authEnabledString); + } + + @Override + public void initializeApplication(ApplicationInitializationContext context) { + String appId = context.getApplicationId().toString(); + try { + ByteBuffer shuffleSecret = context.getApplicationDataForService(); + logger.info("Initializing application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.registerApp(appId, shuffleSecret); + } + } catch (Exception e) { + logger.error("Exception when initializing application {}", appId, e); + } + } + + @Override + public void stopApplication(ApplicationTerminationContext context) { + String appId = context.getApplicationId().toString(); + try { + logger.info("Stopping application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.unregisterApp(appId); + } + } catch (Exception e) { + logger.error("Exception when stopping application {}", appId, e); + } + } + + @Override + public void initializeContainer(ContainerInitializationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Initializing container {}", containerId); + } + + @Override + public void stopContainer(ContainerTerminationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Stopping container {}", containerId); + } + + /** + * Close the shuffle server to clean up any associated state. + */ + @Override + protected void serviceStop() { + try { + if (shuffleServer != null) { + shuffleServer.close(); + } + } catch (Exception e) { + logger.error("Exception when stopping service", e); + } + } + + // Not currently used + @Override + public ByteBuffer getMetaData() { + return ByteBuffer.allocate(0); + } + +} diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java new file mode 100644 index 000000000000..884861752e80 --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn.util; + +import java.util.NoSuchElementException; + +import org.apache.hadoop.conf.Configuration; + +import org.apache.spark.network.util.ConfigProvider; + +/** Use the Hadoop configuration to obtain config values. 
*/ +public class HadoopConfigProvider extends ConfigProvider { + private final Configuration conf; + + public HadoopConfigProvider(Configuration conf) { + this.conf = conf; + } + + @Override + public String get(String name) { + String value = conf.get(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/pom.xml b/pom.xml index eb613531b8a5..88ef67c515b3 100644 --- a/pom.xml +++ b/pom.xml @@ -1229,6 +1229,7 @@ yarn-alpha yarn + network/yarn @@ -1236,6 +1237,7 @@ yarn yarn + network/yarn diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 33618f540176..657e4b443277 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -38,9 +38,9 @@ object BuildCommons { "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) - val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = - Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") - .map(ProjectRef(buildLocation, _)) + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, networkYarn, java8Tests, + sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha", "network-yarn", + "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") .map(ProjectRef(buildLocation, _)) @@ -143,7 +143,7 @@ object SparkBuild extends PomBuild { // TODO: Add Sql to mima checks allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - streamingFlumeSink, networkCommon, networkShuffle).contains(x)).foreach { + streamingFlumeSink, networkCommon, networkShuffle, networkYarn).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 7ee4b5c842df..5f47c79cabae 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.network.sasl.ShuffleSecretManager @deprecated("use yarn/stable", "1.2.0") class ExecutorRunnable( @@ -90,6 +91,21 @@ class ExecutorRunnable( ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // If external shuffle service is enabled, register with the Yarn shuffle service already + // started on the NodeManager and, if authentication is enabled, provide it with our secret + // key for fetching shuffle files later + if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) { + val secretString = securityMgr.getSecretKey() + val secretBytes = + if (secretString != null) { + ShuffleSecretManager.stringToBytes(secretString) + } else { + // Authentication is not enabled, so just provide dummy metadata + ByteBuffer.allocate(0) + } + ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + } + // Send the start request to the ContainerManager val startReq = Records.newRecord(classOf[StartContainerRequest]) .asInstanceOf[StartContainerRequest] diff --git 
a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 0b5a92d87d72..18f48b4b6caf 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.network.sasl.ShuffleSecretManager class ExecutorRunnable( @@ -89,6 +90,21 @@ class ExecutorRunnable( ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // If external shuffle service is enabled, register with the Yarn shuffle service already + // started on the NodeManager and, if authentication is enabled, provide it with our secret + // key for fetching shuffle files later + if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) { + val secretString = securityMgr.getSecretKey() + val secretBytes = + if (secretString != null) { + ShuffleSecretManager.stringToBytes(secretString) + } else { + // Authentication is not enabled, so just provide dummy metadata + ByteBuffer.allocate(0) + } + ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + } + // Send the start request to the ContainerManager nmClient.startContainer(container, ctx) } From fe4ead2995ab8529602090ed21941b6005a07c9d Mon Sep 17 00:00:00 2001 From: "jay@apache.org" Date: Wed, 5 Nov 2014 15:45:34 -0800 Subject: [PATCH 030/652] SPARK-4040. Update documentation to exemplify use of local (n) value, fo... This is a minor docs update which helps to clarify the way local[n] is used for streaming apps. Author: jay@apache.org Closes #2964 from jayunit100/SPARK-4040 and squashes the following commits: 35b5a5e [jay@apache.org] SPARK-4040: Update documentation to exemplify use of local (n) value. (cherry picked from commit 868cd4c3ca11e6ecc4425b972d9a20c360b52425) Signed-off-by: Matei Zaharia --- docs/configuration.md | 10 ++++++++-- docs/streaming-programming-guide.md | 14 +++++++++----- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 685101ea5c9c..0f9eb81f6e99 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -21,16 +21,22 @@ application. These properties can be set directly on a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) passed to your `SparkContext`. `SparkConf` allows you to configure some of the common properties (e.g. master URL and application name), as well as arbitrary key-value pairs through the -`set()` method. For example, we could initialize an application as follows: +`set()` method. For example, we could initialize an application with two threads as follows: + +Note that we run with local[2], meaning two threads - which represents "minimal" parallelism, +which can help detect bugs that only exist when we run in a distributed context. {% highlight scala %} val conf = new SparkConf() - .setMaster("local") + .setMaster("local[2]") .setAppName("CountingSheep") .set("spark.executor.memory", "1g") val sc = new SparkContext(conf) {% endhighlight %} +Note that we can have more than 1 thread in local mode, and in cases like spark streaming, we may actually +require one to prevent any sort of starvation issues. 
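As a minimal illustration of why at least two local threads are needed for a receiver-based streaming job (the host and port below are placeholders):

~~~
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// "local[2]": one thread is taken by the socket receiver below, the other is
// left free to process the received batches. With plain "local" (one thread)
// the receiver would occupy the only core and no batches would be processed.
val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount")
val ssc = new StreamingContext(conf, Seconds(1))

val lines = ssc.socketTextStream("localhost", 9999) // placeholder host and port
lines.count().print()

ssc.start()
ssc.awaitTermination()
~~~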
+ ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 8bbba88b3197..44a1f3ad7560 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -68,7 +68,9 @@ import org.apache.spark._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ -// Create a local StreamingContext with two working thread and batch interval of 1 second +// Create a local StreamingContext with two working thread and batch interval of 1 second. +// The master requires 2 cores to prevent from a starvation scenario. + val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") val ssc = new StreamingContext(conf, Seconds(1)) {% endhighlight %} @@ -586,11 +588,13 @@ Every input DStream (except file stream) is associated with a single [Receiver]( A receiver is run within a Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the Spark Streaming application. Hence, it is important to remember that Spark Streaming application needs to be allocated enough cores to process the received data, as well as, to run the receiver(s). Therefore, few important points to remember are: -##### Points to remember: +##### Points to remember {:.no_toc} -- If the number of cores allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. -- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs with even one input DStream (file streams are okay) as the receiver will occupy that core and there will be no core left to process the data. - +- If the number of threads allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. +- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs using a DStream as the receiver (file streams are okay). So, a "local" master URL in a streaming app is generally going to cause starvation for the processor. +Thus in any streaming app, you generally will want to allocate more than one thread (i.e. set your master to "local[2]") when testing locally. +See [Spark Properties] (configuration.html#spark-properties.html). + ### Basic Sources {:.no_toc} From 0e16d3a3dde7a0988dfd8eff05922a1ac917fe28 Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Wed, 5 Nov 2014 15:49:42 -0800 Subject: [PATCH 031/652] SPARK-3223 runAsSparkUser cannot change HDFS write permission properly i... 
...n mesos cluster mode - change master newer Author: Jongyoul Lee Closes #3034 from jongyoul/SPARK-3223 and squashes the following commits: 42b2ed3 [Jongyoul Lee] SPARK-3223 runAsSparkUser cannot change HDFS write permission properly in mesos cluster mode - change master newer (cherry picked from commit f7ac8c2b1de96151231617846b7468d23379c74a) Signed-off-by: Andrew Or --- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d8c0e2f66df0..e4b859846035 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -93,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { { val ret = driver.run() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 8e2faff90f9b..7d097a3a7aaa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -72,7 +72,7 @@ private[spark] class MesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { val ret = driver.run() From 9ac5c517b64606db7d6b8ac3b823c3d5a45e0ed0 Mon Sep 17 00:00:00 2001 From: Brenden Matthews Date: Wed, 5 Nov 2014 16:02:44 -0800 Subject: [PATCH 032/652] [SPARK-4158] Fix for missing resources. Mesos offers may not contain all resources, and Spark needs to check to ensure they are present and sufficient. Spark may throw an erroneous exception when resources aren't present. Author: Brenden Matthews Closes #3024 from brndnmtthws/fix-mesos-resource-misuse and squashes the following commits: e5f9580 [Brenden Matthews] [SPARK-4158] Fix for missing resources. 
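Put differently, SPARK-4158 treats a resource that is absent from a Mesos offer as zero instead of a hard failure. A self-contained sketch of that lookup; the `NamedScalar` case class is only a stand-in for the Mesos `Resource` protobuf, not a real Spark or Mesos type:

```
// Stand-in for org.apache.mesos.Protos.Resource: a named scalar value in an offer.
case class NamedScalar(name: String, value: Double)

// Mirrors the patched getResource: a resource missing from the offer yields 0.0
// instead of an IllegalArgumentException.
def getResource(resources: Seq[NamedScalar], name: String): Double =
  resources.collectFirst { case NamedScalar(`name`, value) => value }.getOrElse(0.0)

val offer = Seq(NamedScalar("cpus", 4.0))   // an offer that advertises no "mem"
assert(getResource(offer, "cpus") == 4.0)
assert(getResource(offer, "mem") == 0.0)    // previously this threw
```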
(cherry picked from commit cb0eae3b78d7f6f56c0b9521ee48564a4967d3de) Signed-off-by: Andrew Or --- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 3 +-- .../spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index e4b859846035..5289661eb896 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -242,8 +242,7 @@ private[spark] class CoarseMesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Build a Mesos resource protobuf object */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 7d097a3a7aaa..c5f3493477bc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -278,8 +278,7 @@ private[spark] class MesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Turn a Spark TaskDescription into a Mesos task */ From ff84a8ae258083423529885d85bf1d939a62d899 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 5 Nov 2014 19:51:18 -0800 Subject: [PATCH 033/652] [SPARK-4254] [mllib] MovieLensALS bug fix Changed code so it does not try to serialize Params. CC: mengxr debasish83 srowen Author: Joseph K. Bradley Closes #3116 from jkbradley/als-bugfix and squashes the following commits: e575bd8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into als-bugfix 9401b16 [Joseph K. 
Bradley] changed implicitPrefs so it is not serialized to fix MovieLensALS example bug (cherry picked from commit c315d1316cb2372e90ae3a12f72d5b3304435a6b) Signed-off-by: Xiangrui Meng --- .../scala/org/apache/spark/examples/mllib/MovieLensALS.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 8796c28db8a6..91a0a860d6c7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -106,9 +106,11 @@ object MovieLensALS { Logger.getRootLogger.setLevel(Level.WARN) + val implicitPrefs = params.implicitPrefs + val ratings = sc.textFile(params.input).map { line => val fields = line.split("::") - if (params.implicitPrefs) { + if (implicitPrefs) { /* * MovieLens ratings are on a scale of 1-5: * 5: Must see From 7e0da9f6b423842adc9fed2db2d4a80cab541351 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 5 Nov 2014 19:56:16 -0800 Subject: [PATCH 034/652] [SPARK-4262][SQL] add .schemaRDD to JavaSchemaRDD marmbrus Author: Xiangrui Meng Closes #3125 from mengxr/SPARK-4262 and squashes the following commits: 307695e [Xiangrui Meng] add .schemaRDD to JavaSchemaRDD (cherry picked from commit 3d2b5bc5bb979d8b0b71e06bc0f4548376fdbb98) Signed-off-by: Xiangrui Meng --- .../scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 1e0ccb368a27..78e8d908fe0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -47,6 +47,9 @@ class JavaSchemaRDD( private[sql] val baseSchemaRDD = new SchemaRDD(sqlContext, logicalPlan) + /** Returns the underlying Scala SchemaRDD. */ + val schemaRDD: SchemaRDD = baseSchemaRDD + override val classTag = scala.reflect.classTag[Row] override def wrapRDD(rdd: RDD[Row]): JavaRDD[Row] = JavaRDD.fromRDD(rdd) From 70f6f36e03f97847cd2f3e4fe2902bb8459ca6a3 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 5 Nov 2014 20:45:35 -0800 Subject: [PATCH 035/652] [SPARK-4137] [EC2] Don't change working dir on user This issue was uncovered after [this discussion](https://issues.apache.org/jira/browse/SPARK-3398?focusedCommentId=14187471&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-14187471). Don't change the working directory on the user. This breaks relative paths the user may pass in, e.g., for the SSH identity file. ``` ./ec2/spark-ec2 -i ../my.pem ``` This patch will preserve the user's current working directory and allow calls like the one above to work. 
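The rule the spark-ec2 fix enforces is general: user-supplied relative paths should keep resolving against the directory the command was invoked from, while the tool's own bundled files (such as `deploy.generic`) are located relative to the tool's install directory. A rough sketch of that split; the install path and helper names here are illustrative, not taken from spark-ec2:

```
import java.nio.file.{Path, Paths}

// A user argument such as "../my.pem" resolves against the invocation directory
// (the process working directory), which the fix is careful not to change.
def resolveUserArg(arg: String): Path =
  Paths.get(arg).toAbsolutePath.normalize()

// The tool's own assets resolve against wherever the tool is installed,
// so they are found no matter where the user happens to invoke it from.
def resolveBundled(installDir: Path, relative: String): Path =
  installDir.resolve(relative).normalize()

val installDir = Paths.get("/opt/spark/ec2")              // hypothetical install location
val identity   = resolveUserArg("../my.pem")              // depends on the caller's cwd
val template   = resolveBundled(installDir, "deploy.generic")
```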
Author: Nicholas Chammas Closes #2988 from nchammas/spark-ec2-cwd and squashes the following commits: f3850b5 [Nicholas Chammas] pep8 fix fbc20c7 [Nicholas Chammas] revert to old commenting style 752f958 [Nicholas Chammas] specify deploy.generic path absolutely bcdf6a5 [Nicholas Chammas] fix typo 77871a2 [Nicholas Chammas] add clarifying comment ce071fc [Nicholas Chammas] don't change working dir (cherry picked from commit db45f5ad0368760dbeaa618a04f66ae9b2bed656) Signed-off-by: Shivaram Venkataraman --- ec2/spark-ec2 | 8 ++++++-- ec2/spark_ec2.py | 12 +++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 31f9771223e5..4aa908242eea 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -18,5 +18,9 @@ # limitations under the License. # -cd "`dirname $0`" -PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@" +# Preserve the user's CWD so that relative paths are passed correctly to +#+ the underlying Python script. +SPARK_EC2_DIR="$(dirname $0)" + +PYTHONPATH="${SPARK_EC2_DIR}/third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" \ + python "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 50f88f735650..a5396c237591 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -40,6 +40,7 @@ from boto import ec2 DEFAULT_SPARK_VERSION = "1.1.0" +SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) MESOS_SPARK_EC2_BRANCH = "v4" # A URL prefix from which to fetch AMI information @@ -593,7 +594,14 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): ) print "Deploying files to master..." - deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) + deploy_files( + conn=conn, + root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", + opts=opts, + master_nodes=master_nodes, + slave_nodes=slave_nodes, + modules=modules + ) print "Running setup on master..." setup_spark_cluster(master, opts) @@ -730,6 +738,8 @@ def get_num_disks(instance_type): # cluster (e.g. lists of masters and slaves). Files are only deployed to # the first master instance in the cluster, and we expect the setup # script to be run on that instance to copy them to other nodes. +# +# root_dir should be an absolute path to the directory with the files we want to deploy. def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): active_master = master_nodes[0].public_dns_name From 2c84178b8283269512b1c968b9995a7bdedd7aa5 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 6 Nov 2014 00:03:03 -0800 Subject: [PATCH 036/652] [SPARK-4255] Fix incorrect table striping This commit stripes table rows after hiding some rows, to ensure that rows are correct striped to alternate white and grey even when rows are hidden by default. 
Author: Kay Ousterhout Closes #3117 from kayousterhout/striping and squashes the following commits: be6e10a [Kay Ousterhout] [SPARK-4255] Fix incorrect table striping (cherry picked from commit 5f27ae16d5b016fae4afeb0f2ad779fd3130b390) Signed-off-by: Kay Ousterhout --- .../org/apache/spark/ui/static/additional-metrics.js | 2 ++ core/src/main/resources/org/apache/spark/ui/static/table.js | 5 ----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index c5936b5038ac..badd85ed48c8 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -39,6 +39,8 @@ $(function() { var column = "table ." + $(this).attr("name"); $(column).hide(); }); + // Stripe table rows after rows have been hidden to ensure correct striping. + stripeTables(); $("input:checkbox").click(function() { var column = "table ." + $(this).attr("name"); diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 32187ba6e8df..6bb03015abb5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -28,8 +28,3 @@ function stripeTables() { }); }); } - -/* Stripe all tables after pages finish loading. */ -$(function() { - stripeTables(); -}); From 01484455c4ee4ee8e848be56f395d38841fbf86a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Nov 2014 00:22:19 -0800 Subject: [PATCH 037/652] [SPARK-4186] add binaryFiles and binaryRecords in Python add binaryFiles() and binaryRecords() in Python ``` binaryFiles(self, path, minPartitions=None): :: Developer API :: Read a directory of binary files from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI as a byte array. Each file is read as a single record and returned in a key-value pair, where the key is the path of each file, the value is the content of each file. Note: Small files are preferred, large file is also allowable, but may cause bad performance. binaryRecords(self, path, recordLength): Load data from a flat binary file, assuming each record is a set of numbers with the specified numerical format (see ByteBuffer), and the number of bytes per record is constant. 
:param path: Directory to the input data files :param recordLength: The length at which to split the records ``` Author: Davies Liu Closes #3078 from davies/binary and squashes the following commits: cd0bdbd [Davies Liu] Merge branch 'master' of github.com:apache/spark into binary 3aa349b [Davies Liu] add experimental notes 24e84b6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into binary 5ceaa8a [Davies Liu] Merge branch 'master' of github.com:apache/spark into binary 1900085 [Davies Liu] bugfix bb22442 [Davies Liu] add binaryFiles and binaryRecords in Python (cherry picked from commit b41a39e24038876359aeb7ce2bbbb4de2234e5f3) Signed-off-by: Matei Zaharia --- .../scala/org/apache/spark/SparkContext.scala | 4 ++ .../spark/api/java/JavaSparkContext.scala | 12 ++--- .../apache/spark/api/python/PythonRDD.scala | 45 ++++++++++++------- python/pyspark/context.py | 32 ++++++++++++- python/pyspark/tests.py | 19 ++++++++ 5 files changed, 90 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3cdaa6a9cc8a..03ea672c813d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -560,6 +560,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { /** + * :: Experimental :: + * * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file * (useful for binary data) * @@ -602,6 +604,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { } /** + * :: Experimental :: + * * Load data from a flat binary file, assuming the length of each record is constant. * * @param path Directory to the input data files diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index e3aeba7e6c39..5c6e8d32c5c8 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,11 +21,6 @@ import java.io.Closeable import java.util import java.util.{Map => JMap} -import java.io.DataInputStream - -import org.apache.hadoop.io.{BytesWritable, LongWritable} -import org.apache.spark.input.{PortableDataStream, FixedLengthBinaryInputFormat} - import scala.collection.JavaConversions import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -33,6 +28,7 @@ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration +import org.apache.spark.input.PortableDataStream import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} @@ -286,6 +282,8 @@ class JavaSparkContext(val sc: SparkContext) new JavaPairRDD(sc.binaryFiles(path, minPartitions)) /** + * :: Experimental :: + * * Read a directory of binary files from HDFS, a local file system (available on all nodes), * or any Hadoop-supported file system URI as a byte array. Each file is read as a single * record and returned in a key-value pair, where the key is the path of each file, @@ -312,15 +310,19 @@ class JavaSparkContext(val sc: SparkContext) * * @note Small files are preferred; very large files but may cause bad performance. 
*/ + @Experimental def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] = new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions)) /** + * :: Experimental :: + * * Load data from a flat binary file, assuming the length of each record is constant. * * @param path Directory to the input data files * @return An RDD of data with values, represented as byte arrays */ + @Experimental def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = { new JavaRDD(sc.binaryRecords(path, recordLength)) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index e94ccdcd47bb..45beb8fc8c92 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -21,6 +21,8 @@ import java.io._ import java.net._ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} +import org.apache.spark.input.PortableDataStream + import scala.collection.JavaConversions._ import scala.collection.mutable import scala.language.existentials @@ -395,22 +397,33 @@ private[spark] object PythonRDD extends Logging { newIter.asInstanceOf[Iterator[String]].foreach { str => writeUTF(str, dataOut) } - case pair: Tuple2[_, _] => - pair._1 match { - case bytePair: Array[Byte] => - newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair => - dataOut.writeInt(pair._1.length) - dataOut.write(pair._1) - dataOut.writeInt(pair._2.length) - dataOut.write(pair._2) - } - case stringPair: String => - newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair => - writeUTF(pair._1, dataOut) - writeUTF(pair._2, dataOut) - } - case other => - throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass) + case stream: PortableDataStream => + newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream => + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, stream: PortableDataStream) => + newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach { + case (key, stream) => + writeUTF(key, dataOut) + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, value: String) => + newIter.asInstanceOf[Iterator[(String, String)]].foreach { + case (key, value) => + writeUTF(key, dataOut) + writeUTF(value, dataOut) + } + case (key: Array[Byte], value: Array[Byte]) => + newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { + case (key, value) => + dataOut.writeInt(key.length) + dataOut.write(key) + dataOut.writeInt(value.length) + dataOut.write(value) } case other => throw new SparkException("Unexpected element type " + first.getClass) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a0e4821728c8..faa5952258ae 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer, AutoBatchedSerializer + PairDeserializer, CompressedSerializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -388,6 +388,36 @@ def wholeTextFiles(self, path, minPartitions=None, 
use_unicode=True): return RDD(self._jsc.wholeTextFiles(path, minPartitions), self, PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode))) + def binaryFiles(self, path, minPartitions=None): + """ + :: Experimental :: + + Read a directory of binary files from HDFS, a local file system + (available on all nodes), or any Hadoop-supported file system URI + as a byte array. Each file is read as a single record and returned + in a key-value pair, where the key is the path of each file, the + value is the content of each file. + + Note: Small files are preferred, large file is also allowable, but + may cause bad performance. + """ + minPartitions = minPartitions or self.defaultMinPartitions + return RDD(self._jsc.binaryFiles(path, minPartitions), self, + PairDeserializer(UTF8Deserializer(), NoOpSerializer())) + + def binaryRecords(self, path, recordLength): + """ + :: Experimental :: + + Load data from a flat binary file, assuming each record is a set of numbers + with the specified numerical format (see ByteBuffer), and the number of + bytes per record is constant. + + :param path: Directory to the input data files + :param recordLength: The length at which to split the records + """ + return RDD(self._jsc.binaryRecords(path, recordLength), self, NoOpSerializer()) + def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() if not d: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7e61b017efa7..9f625c5c6ca4 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1110,6 +1110,25 @@ def test_converters(self): (u'\x03', [2.0])] self.assertEqual(maps, em) + def test_binary_files(self): + path = os.path.join(self.tempdir.name, "binaryfiles") + os.mkdir(path) + data = "short binary data" + with open(os.path.join(path, "part-0000"), 'w') as f: + f.write(data) + [(p, d)] = self.sc.binaryFiles(path).collect() + self.assertTrue(p.endswith("part-0000")) + self.assertEqual(d, data) + + def test_binary_records(self): + path = os.path.join(self.tempdir.name, "binaryrecords") + os.mkdir(path) + with open(os.path.join(path, "part-0000"), 'w') as f: + for i in range(100): + f.write('%04d' % i) + result = self.sc.binaryRecords(path, 4).map(int).collect() + self.assertEqual(range(100), result) + class OutputFormatTests(ReusedPySparkTestCase): From aaaeaf93902a1954df11fa4982b1c6c7e29f5b8d Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 10:45:46 -0800 Subject: [PATCH 038/652] [SPARK-4264] Completion iterator should only invoke callback once Author: Aaron Davidson Closes #3128 from aarondav/compiter and squashes the following commits: 698e4be [Aaron Davidson] [SPARK-4264] Completion iterator should only invoke callback once (cherry picked from commit 23eaf0e12ff221dcca40a79e61b6cc5e7c846cb5) Signed-off-by: Aaron Davidson --- .../spark/util/CompletionIterator.scala | 5 +- .../spark/util/CompletionIteratorSuite.scala | 47 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index b6a099825f01..390310243ee0 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -25,10 +25,13 @@ private[spark] // scalastyle:off abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends 
Iterator[A] { // scalastyle:on + + private[this] var completed = false def next() = sub.next() def hasNext = { val r = sub.hasNext - if (!r) { + if (!r && !completed) { + completed = true completion() } r diff --git a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala new file mode 100644 index 000000000000..3755d43e25ea --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.scalatest.FunSuite + +class CompletionIteratorSuite extends FunSuite { + test("basic test") { + var numTimesCompleted = 0 + val iter = List(1, 2, 3).iterator + val completionIter = CompletionIterator[Int, Iterator[Int]](iter, { numTimesCompleted += 1 }) + + assert(completionIter.hasNext) + assert(completionIter.next() === 1) + assert(numTimesCompleted === 0) + + assert(completionIter.hasNext) + assert(completionIter.next() === 2) + assert(numTimesCompleted === 0) + + assert(completionIter.hasNext) + assert(completionIter.next() === 3) + assert(numTimesCompleted === 0) + + assert(!completionIter.hasNext) + assert(numTimesCompleted === 1) + + // SPARK-4264: Calling hasNext should not trigger the completion callback again. + assert(!completionIter.hasNext) + assert(numTimesCompleted === 1) + } +} From 9061bc4e127abb0c44e37f1b8b7706883d451bc7 Mon Sep 17 00:00:00 2001 From: lianhuiwang Date: Thu, 6 Nov 2014 10:46:45 -0800 Subject: [PATCH 039/652] [SPARK-4249][GraphX]fix a problem of EdgePartitionBuilder in Graphx at first srcIds is not initialized and are all 0. 
so we use edgeArray(0).srcId to currSrcId Author: lianhuiwang Closes #3138 from lianhuiwang/SPARK-4249 and squashes the following commits: 3f4e503 [lianhuiwang] fix a problem of EdgePartitionBuilder in Graphx (cherry picked from commit d15c6e9dc2860bbe56e31ddf71218ccc6d5c841d) Signed-off-by: Ankur Dave --- .../org/apache/spark/graphx/impl/EdgePartitionBuilder.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 4520beb99151..2b6137be2554 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -45,8 +45,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and // adding them to the index if (edgeArray.length > 0) { - index.update(srcIds(0), 0) - var currSrcId: VertexId = srcIds(0) + index.update(edgeArray(0).srcId, 0) + var currSrcId: VertexId = edgeArray(0).srcId var i = 0 while (i < edgeArray.size) { srcIds(i) = edgeArray(i).srcId From 9ea0fac0eafd7264a30f36c0d20863700245991f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 6 Nov 2014 15:31:07 -0800 Subject: [PATCH 040/652] [HOT FIX] Make distribution fails This was added by me in https://github.com/apache/spark/commit/61a5cced049a8056292ba94f23fa7bd040f50685. The real fix will be added in [SPARK-4281](https://issues.apache.org/jira/browse/SPARK-4281). Author: Andrew Or Closes #3145 from andrewor14/fix-make-distribution and squashes the following commits: c78be61 [Andrew Or] Hot fix make distribution (cherry picked from commit 470881b24a503c9edcaed159c29bafa446ab0e9a) Signed-off-by: Andrew Or --- make-distribution.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/make-distribution.sh b/make-distribution.sh index fac7f7e284be..0bc839e1dbe4 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -181,9 +181,6 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DI # Copy jars cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" -cp "$FWDIR"/network/yarn/target/scala*/spark-network-yarn*.jar "$DISTDIR/lib/" -cp "$FWDIR"/network/yarn/target/scala*/spark-network-shuffle*.jar "$DISTDIR/lib/" -cp "$FWDIR"/network/yarn/target/scala*/spark-network-common*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" From 6508953a4b8622312c1f0ae4b4b4275b5a2c2bd6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 6 Nov 2014 17:18:49 -0800 Subject: [PATCH 041/652] [SPARK-3797] Minor addendum to Yarn shuffle service I did not realize there was a `network.util.JavaUtils` when I wrote this code. This PR moves the `ByteBuffer` string conversion to the appropriate place. I tested the changes on a stable yarn cluster. 
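The conversion being moved is a plain UTF-8 round trip between the secret string and the `ByteBuffer` that YARN's service-data API expects. A standalone sketch of that contract; the real `JavaUtils` helpers in the diff below are built on Netty and Guava instead:

```
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets.UTF_8

// Model of JavaUtils.stringToBytes: wrap the UTF-8 bytes of the secret.
def stringToBytes(s: String): ByteBuffer = ByteBuffer.wrap(s.getBytes(UTF_8))

// Model of JavaUtils.bytesToString: read the buffer back without disturbing
// the caller's position (hence the duplicate()).
def bytesToString(b: ByteBuffer): String = {
  val copy = new Array[Byte](b.remaining())
  b.duplicate().get(copy)
  new String(copy, UTF_8)
}

assert(bytesToString(stringToBytes("shuffle-secret")) == "shuffle-secret")
```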
Author: Andrew Or Closes #3144 from andrewor14/yarn-shuffle-util and squashes the following commits: b6c08bf [Andrew Or] Remove unused import 94e205c [Andrew Or] Use netty Unpooled 85202a5 [Andrew Or] Use guava Charsets 057135b [Andrew Or] Reword comment adf186d [Andrew Or] Move byte buffer String conversion logic to JavaUtils (cherry picked from commit 96136f222abd4f3abd10cb78a4ebecdb21f3bde7) Signed-off-by: Andrew Or --- .../apache/spark/network/util/JavaUtils.java | 20 ++++++++++++++++ .../network/sasl/ShuffleSecretManager.java | 24 ++----------------- .../spark/deploy/yarn/ExecutorRunnable.scala | 5 ++-- .../spark/deploy/yarn/ExecutorRunnable.scala | 5 ++-- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 40b71b0c87a4..2856d1c8c933 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,6 +17,8 @@ package org.apache.spark.network.util; +import java.nio.ByteBuffer; + import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.Closeable; @@ -25,6 +27,8 @@ import java.io.ObjectOutputStream; import com.google.common.io.Closeables; +import com.google.common.base.Charsets; +import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -73,4 +77,20 @@ public static int nonNegativeHash(Object obj) { int hash = obj.hashCode(); return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0; } + + /** + * Convert the given string to a byte buffer. The resulting buffer can be + * converted back to the same string through {@link #bytesToString(ByteBuffer)}. + */ + public static ByteBuffer stringToBytes(String s) { + return Unpooled.wrappedBuffer(s.getBytes(Charsets.UTF_8)).nioBuffer(); + } + + /** + * Convert the given byte buffer to a string. The resulting string can be + * converted back to the same byte buffer through {@link #stringToBytes(String)}. + */ + public static String bytesToString(ByteBuffer b) { + return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java index e66c4af0f1eb..351c7930a900 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -19,13 +19,13 @@ import java.lang.Override; import java.nio.ByteBuffer; -import java.nio.charset.Charset; import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.util.JavaUtils; /** * A class that manages shuffle secret used by the external shuffle service. 
@@ -34,30 +34,10 @@ public class ShuffleSecretManager implements SecretKeyHolder { private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class); private final ConcurrentHashMap shuffleSecretMap; - private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); - // Spark user used for authenticating SASL connections // Note that this must match the value in org.apache.spark.SecurityManager private static final String SPARK_SASL_USER = "sparkSaslUser"; - /** - * Convert the given string to a byte buffer. The resulting buffer can be converted back to - * the same string through {@link #bytesToString(ByteBuffer)}. This is used if the external - * shuffle service represents shuffle secrets as bytes buffers instead of strings. - */ - public static ByteBuffer stringToBytes(String s) { - return ByteBuffer.wrap(s.getBytes(UTF8_CHARSET)); - } - - /** - * Convert the given byte buffer to a string. The resulting string can be converted back to - * the same byte buffer through {@link #stringToBytes(String)}. This is used if the external - * shuffle service represents shuffle secrets as bytes buffers instead of strings. - */ - public static String bytesToString(ByteBuffer b) { - return new String(b.array(), UTF8_CHARSET); - } - public ShuffleSecretManager() { shuffleSecretMap = new ConcurrentHashMap(); } @@ -80,7 +60,7 @@ public void registerApp(String appId, String shuffleSecret) { * Register an application with its secret specified as a byte buffer. */ public void registerApp(String appId, ByteBuffer shuffleSecret) { - registerApp(appId, bytesToString(shuffleSecret)); + registerApp(appId, JavaUtils.bytesToString(shuffleSecret)); } /** diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 5f47c79cabae..7023a1170654 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,7 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} import org.apache.spark.{SecurityManager, SparkConf, Logging} -import org.apache.spark.network.sasl.ShuffleSecretManager +import org.apache.spark.network.util.JavaUtils @deprecated("use yarn/stable", "1.2.0") class ExecutorRunnable( @@ -98,7 +98,8 @@ class ExecutorRunnable( val secretString = securityMgr.getSecretKey() val secretBytes = if (secretString != null) { - ShuffleSecretManager.stringToBytes(secretString) + // This conversion must match how the YarnShuffleService decodes our secret + JavaUtils.stringToBytes(secretString) } else { // Authentication is not enabled, so just provide dummy metadata ByteBuffer.allocate(0) diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 18f48b4b6caf..fdd3c2300fa7 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,7 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} import org.apache.spark.{SecurityManager, SparkConf, Logging} -import org.apache.spark.network.sasl.ShuffleSecretManager +import org.apache.spark.network.util.JavaUtils class ExecutorRunnable( @@ -97,7 +97,8 @@ 
class ExecutorRunnable( val secretString = securityMgr.getSecretKey() val secretBytes = if (secretString != null) { - ShuffleSecretManager.stringToBytes(secretString) + // This conversion must match how the YarnShuffleService decodes our secret + JavaUtils.stringToBytes(secretString) } else { // Authentication is not enabled, so just provide dummy metadata ByteBuffer.allocate(0) From cbe9a6c8a822beaea5a79e4155759c39d078ea2c Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 17:20:46 -0800 Subject: [PATCH 042/652] [SPARK-4277] Support external shuffle service on Standalone Worker Author: Aaron Davidson Closes #3142 from aarondav/worker and squashes the following commits: 3780bd7 [Aaron Davidson] Address comments 2dcdfc1 [Aaron Davidson] Add private[worker] 47f49d3 [Aaron Davidson] NettyBlockTransferService shouldn't care about app ids (it's only b/t executors) 258417c [Aaron Davidson] [SPARK-4277] Support external shuffle service on executor (cherry picked from commit 6e9ef10fd7446a11f37446c961916ba2a8e02cb8) Signed-off-by: Andrew Or --- .../org/apache/spark/SecurityManager.scala | 14 +--- .../StandaloneWorkerShuffleService.scala | 66 +++++++++++++++++++ .../apache/spark/deploy/worker/Worker.scala | 8 ++- .../storage/ShuffleBlockFetcherIterator.scala | 2 +- .../NettyBlockTransferSecuritySuite.scala | 12 ---- .../spark/network/sasl/SaslMessage.java | 3 +- 6 files changed, 79 insertions(+), 26 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index dee935ffad51..dbff9d12b5ad 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -343,15 +343,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with */ def getSecretKey(): String = secretKey - override def getSaslUser(appId: String): String = { - val myAppId = sparkConf.getAppId - require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") - getSaslUser() - } - - override def getSecretKey(appId: String): String = { - val myAppId = sparkConf.getAppId - require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") - getSecretKey() - } + // Default SecurityManager only has a single secret key, so ignore appId. + override def getSaslUser(appId: String): String = getSaslUser() + override def getSecretKey(appId: String): String = getSecretKey() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala new file mode 100644 index 000000000000..88118e283774 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker + +import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.sasl.SaslRpcHandler +import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler + +/** + * Provides a server from which Executors can read shuffle files (rather than reading directly from + * each other), to provide uninterrupted access to the files in the face of executors being turned + * off or killed. + * + * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. + */ +private[worker] +class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) + extends Logging { + + private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) + private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) + private val useSasl: Boolean = securityManager.isAuthenticationEnabled() + + private val transportConf = SparkTransportConf.fromSparkConf(sparkConf) + private val blockHandler = new ExternalShuffleBlockHandler() + private val transportContext: TransportContext = { + val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler + new TransportContext(transportConf, handler) + } + + private var server: TransportServer = _ + + /** Starts the external shuffle service if the user has configured us to. */ + def startIfEnabled() { + if (enabled) { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + } + + def stop() { + if (enabled && server != null) { + server.close() + server = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f1f66d0903f1..ca262de832e2 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -111,6 +111,9 @@ private[spark] class Worker( val drivers = new HashMap[String, DriverRunner] val finishedDrivers = new HashMap[String, DriverRunner] + // The shuffle service is not actually started unless configured. 
+ val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host @@ -154,6 +157,7 @@ private[spark] class Worker( logInfo("Spark home: " + sparkHome) createWorkDir() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() registerWithMaster() @@ -419,6 +423,7 @@ private[spark] class Worker( registrationRetryTimer.foreach(_.cancel()) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) + shuffleService.stop() webUi.stop() metricsSystem.stop() } @@ -441,7 +446,8 @@ private[spark] object Worker extends Logging { cores: Int, memory: Int, masterUrls: Array[String], - workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { + workDir: String, + workerNumber: Option[Int] = None): (ActorSystem, Int) = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val conf = new SparkConf diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 1e579187e419..6b1f57a06943 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -92,7 +92,7 @@ final class ShuffleBlockFetcherIterator( * Current [[FetchResult]] being processed. We track this so we can release the current buffer * in case of a runtime exception when processing the current buffer. */ - private[this] var currentResult: FetchResult = null + @volatile private[this] var currentResult: FetchResult = null /** * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index bed0ed9d713d..9162ec980166 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -89,18 +89,6 @@ class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with Sh } } - test("security mismatch app ids") { - val conf0 = new SparkConf() - .set("spark.authenticate", "true") - .set("spark.authenticate.secret", "good") - .set("spark.app.id", "app-id") - val conf1 = conf0.clone.set("spark.app.id", "other-id") - testConnection(conf0, conf1) match { - case Success(_) => fail("Should have failed") - case Failure(t) => t.getMessage should include ("SASL appId app-id did not match") - } - } - /** * Creates two servers with different configurations and sees if they can talk. 
* Returns Success() if they can transfer a block, and Failure() if the block transfer was failed diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 5b77e18c26bf..599cc6428c90 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -58,7 +58,8 @@ public void encode(ByteBuf buf) { public static SaslMessage decode(ByteBuf buf) { if (buf.readByte() != TAG_BYTE) { - throw new IllegalStateException("Expected SaslMessage, received something else"); + throw new IllegalStateException("Expected SaslMessage, received something else" + + " (maybe your client does not have SASL enabled?)"); } int idLength = buf.readInt(); From c1ea5c542f3267c0b23a7775887e3a6ece793fe3 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 18:39:14 -0800 Subject: [PATCH 043/652] [SPARK-4188] [Core] Perform network-level retry of shuffle file fetches This adds a RetryingBlockFetcher to the NettyBlockTransferService which is wrapped around our typical OneForOneBlockFetcher, adding retry logic in the event of an IOException. This sort of retry allows us to avoid marking an entire executor as failed due to garbage collection or high network load. TODO: - [x] unit tests - [x] put in ExternalShuffleClient too Author: Aaron Davidson Closes #3101 from aarondav/retry and squashes the following commits: 72a2a32 [Aaron Davidson] Add that we should remove the condition around the retry thingy c7fd107 [Aaron Davidson] Fix unit tests e80e4c2 [Aaron Davidson] Address initial comments 6f594cd [Aaron Davidson] Fix unit test 05ff43c [Aaron Davidson] Add to external shuffle client and add unit test 66e5a24 [Aaron Davidson] [SPARK-4238] [Core] Perform network-level retry of shuffle file fetches (cherry picked from commit f165b2bbf5d4acf34d826fa55b900f5bbc295654) Signed-off-by: Reynold Xin --- .../netty/NettyBlockTransferService.scala | 21 +- .../spark/network/client/TransportClient.java | 16 +- .../client/TransportClientFactory.java | 13 +- .../client/TransportResponseHandler.java | 3 +- .../network/protocol/MessageEncoder.java | 2 +- .../spark/network/server/TransportServer.java | 8 +- .../apache/spark/network/util/NettyUtils.java | 14 +- .../spark/network/util/TransportConf.java | 17 + .../network/TransportClientFactorySuite.java | 7 +- .../shuffle/ExternalShuffleClient.java | 31 +- .../shuffle/OneForOneBlockFetcher.java | 9 +- .../network/shuffle/RetryingBlockFetcher.java | 234 +++++++++++++ .../network/sasl/SaslIntegrationSuite.java | 4 +- .../ExternalShuffleIntegrationSuite.java | 18 +- .../shuffle/ExternalShuffleSecuritySuite.java | 6 +- .../shuffle/RetryingBlockFetcherSuite.java | 310 ++++++++++++++++++ 16 files changed, 668 insertions(+), 45 deletions(-) create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 0d1fc81d2a16..b937ea825f49 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -27,7 +27,7 @@ 
import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCal import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -71,9 +71,22 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage listener: BlockFetchingListener): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { - val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, blockIds.toArray, listener) - .start(OpenBlocks(blockIds.map(BlockId.apply))) + val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { + override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { + val client = clientFactory.createClient(host, port) + new OneForOneBlockFetcher(client, blockIds.toArray, listener) + .start(OpenBlocks(blockIds.map(BlockId.apply))) + } + } + + val maxRetries = transportConf.maxIORetries() + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start() + } else { + blockFetchStarter.createAndStart(blockIds, listener) + } } catch { case e: Exception => logError("Exception while beginning fetchBlocks", e) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index a08cee02dd57..4e944114e817 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -18,7 +18,9 @@ package org.apache.spark.network.client; import java.io.Closeable; +import java.io.IOException; import java.util.UUID; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import com.google.common.base.Objects; @@ -116,8 +118,12 @@ public void operationComplete(ChannelFuture future) throws Exception { serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeFetchRequest(streamChunkId); - callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause())); channel.close(); + try { + callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } } } }); @@ -147,8 +153,12 @@ public void operationComplete(ChannelFuture future) throws Exception { serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeRpcRequest(requestId); - callback.onFailure(new RuntimeException(errorMsg, future.cause())); channel.close(); + try { + callback.onFailure(new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } } } }); @@ -175,6 +185,8 @@ public void onFailure(Throwable e) { try { return 
result.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (ExecutionException e) { + throw Throwables.propagate(e.getCause()); } catch (Exception e) { throw Throwables.propagate(e); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 1723fed30725..397d3a8455c8 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -18,12 +18,12 @@ package org.apache.spark.network.client; import java.io.Closeable; +import java.io.IOException; import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.List; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import com.google.common.base.Preconditions; @@ -44,7 +44,6 @@ import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.IOMode; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -93,15 +92,17 @@ public TransportClientFactory( * * Concurrency: This method is safe to call from multiple threads. */ - public TransportClient createClient(String remoteHost, int remotePort) { + public TransportClient createClient(String remoteHost, int remotePort) throws IOException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); TransportClient cachedClient = connectionPool.get(address); if (cachedClient != null) { if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); return cachedClient; } else { + logger.info("Found inactive connection to {}, closing it.", address); connectionPool.remove(address, cachedClient); // Remove inactive clients. 
} } @@ -133,10 +134,10 @@ public void initChannel(SocketChannel ch) { long preConnect = System.currentTimeMillis(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { - throw new RuntimeException( + throw new IOException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } else if (cf.cause() != null) { - throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); + throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); } TransportClient client = clientRef.get(); @@ -198,7 +199,7 @@ public void close() { */ private PooledByteBufAllocator createPooledByteBufAllocator() { return new PooledByteBufAllocator( - PlatformDependent.directBufferPreferred(), + conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(), getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), getPrivateStaticField("DEFAULT_PAGE_SIZE"), diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index d8965590b34d..2044afb0d85d 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.client; +import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -94,7 +95,7 @@ public void channelUnregistered() { String remoteAddress = NettyUtils.getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", numOutstandingRequests(), remoteAddress); - failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed")); + failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed")); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 4cb8becc3ed2..91d1e8a538a7 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -66,7 +66,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { // All messages have the frame length, message type, and message itself. 
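    // Editor's note (illustrative, not part of this patch): the frame laid out below is
    //   [8-byte frameLength][message type tag][encoded message][optional body]
    // so that frameLength = 8 + msgType.encodedLength() + in.encodedLength() + bodyLength.
    // For example, assuming a 1-byte type tag, a message encoding to 12 bytes with a 1 MiB body
    // would advertise frameLength = 8 + 1 + 12 + 1048576 = 1048597.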
int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); long frameLength = headerLength + bodyLength; - ByteBuf header = ctx.alloc().buffer(headerLength); + ByteBuf header = ctx.alloc().heapBuffer(headerLength); header.writeLong(frameLength); msgType.encode(header); in.encode(header); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 70da48ca8ee7..579676c2c356 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -28,6 +28,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.PlatformDependent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -71,11 +72,14 @@ private void init(int portToBind) { NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); EventLoopGroup workerGroup = bossGroup; + PooledByteBufAllocator allocator = new PooledByteBufAllocator( + conf.preferDirectBufs() && PlatformDependent.directBufferPreferred()); + bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); + .option(ChannelOption.ALLOCATOR, allocator) + .childOption(ChannelOption.ALLOCATOR, allocator); if (conf.backLog() > 0) { bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index b1872341198e..2a7664fe8938 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -37,13 +37,17 @@ * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. */ public class NettyUtils { - /** Creates a Netty EventLoopGroup based on the IOMode. */ - public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { - - ThreadFactory threadFactory = new ThreadFactoryBuilder() + /** Creates a new ThreadFactory which prefixes each thread with the given name. */ + public static ThreadFactory createThreadFactory(String threadPoolPrefix) { + return new ThreadFactoryBuilder() .setDaemon(true) - .setNameFormat(threadPrefix + "-%d") + .setNameFormat(threadPoolPrefix + "-%d") .build(); + } + + /** Creates a Netty EventLoopGroup based on the IOMode. 
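   * For example (editor's illustration, not part of this patch), a call such as
   * createEventLoop(IOMode.NIO, 8, "shuffle-client") is expected to return an event loop group
   * whose daemon worker threads are named "shuffle-client-0", "shuffle-client-1", and so on,
   * via the thread factory above.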
*/ + public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { + ThreadFactory threadFactory = createThreadFactory(threadPrefix); switch (mode) { case NIO: diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 823790dd3c66..787a8f0031af 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -30,6 +30,11 @@ public TransportConf(ConfigProvider conf) { /** IO mode: nio or epoll */ public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + /** If true, we will prefer allocating off-heap byte buffers within Netty. */ + public boolean preferDirectBufs() { + return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + } + /** Connect timeout in secs. Default 120 secs. */ public int connectionTimeoutMs() { return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; @@ -58,4 +63,16 @@ public int connectionTimeoutMs() { /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); } + + /** + * Max number of times we will try IO exceptions (such as connection timeouts) per request. + * If set to 0, we will not do any retries. + */ + public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + + /** + * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. + * Only relevant if maxIORetries > 0. + */ + public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); } } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 5a10fdb3842e..822bef1d81b2 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.io.IOException; import java.util.concurrent.TimeoutException; import org.junit.After; @@ -57,7 +58,7 @@ public void tearDown() { } @Test - public void createAndReuseBlockClients() throws TimeoutException { + public void createAndReuseBlockClients() throws IOException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); @@ -70,7 +71,7 @@ public void createAndReuseBlockClients() throws TimeoutException { } @Test - public void neverReturnInactiveClients() throws Exception { + public void neverReturnInactiveClients() throws IOException, InterruptedException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); c1.close(); @@ -88,7 +89,7 @@ public void neverReturnInactiveClients() throws Exception { } @Test - public void closeBlockClientsWithFactory() throws TimeoutException { + public void closeBlockClientsWithFactory() throws IOException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), 
server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 3aa95d00f6b2..27884b82c8cb 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.io.IOException; import java.util.List; import com.google.common.collect.Lists; @@ -76,17 +77,33 @@ public void init(String appId) { @Override public void fetchBlocks( - String host, - int port, - String execId, + final String host, + final int port, + final String execId, String[] blockIds, BlockFetchingListener listener) { assert appId != null : "Called before init()"; logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { - TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, blockIds, listener) - .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = + new RetryingBlockFetcher.BlockFetchStarter() { + @Override + public void createAndStart(String[] blockIds, BlockFetchingListener listener) + throws IOException { + TransportClient client = clientFactory.createClient(host, port); + new OneForOneBlockFetcher(client, blockIds, listener) + .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + } + }; + + int maxRetries = conf.maxIORetries(); + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start(); + } else { + blockFetchStarter.createAndStart(blockIds, listener); + } } catch (Exception e) { logger.error("Exception while beginning fetchBlocks", e); for (String blockId : blockIds) { @@ -108,7 +125,7 @@ public void registerWithShuffleServer( String host, int port, String execId, - ExecutorShuffleInfo executorInfo) { + ExecutorShuffleInfo executorInfo) throws IOException { assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); byte[] registerExecutorMessage = diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 39b6f30f92ba..9e77a1f68c4b 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -51,9 +51,6 @@ public OneForOneBlockFetcher( TransportClient client, String[] blockIds, BlockFetchingListener listener) { - if (blockIds.length == 0) { - throw new IllegalArgumentException("Zero-sized blockIds array"); - } this.client = client; this.blockIds = blockIds; this.listener = listener; @@ -82,6 +79,10 @@ public void onFailure(int chunkIndex, Throwable e) { * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling. 
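   * Editor's note (not part of this patch): with the zero-length check moved from the constructor
   * into this method, e.g. (hypothetical) new OneForOneBlockFetcher(client, new String[0], listener)
   * now constructs fine, but its start(...) throws IllegalArgumentException.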
*/ public void start(Object openBlocksMessage) { + if (blockIds.length == 0) { + throw new IllegalArgumentException("Zero-sized blockIds array"); + } + client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { @@ -95,7 +96,7 @@ public void onSuccess(byte[] response) { client.fetchChunk(streamHandle.streamId, i, chunkCallback); } } catch (Exception e) { - logger.error("Failed while starting block fetches", e); + logger.error("Failed while starting block fetches after success", e); failRemainingBlocks(blockIds, e); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java new file mode 100644 index 000000000000..f8a1a266863b --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.Uninterruptibles; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Wraps another BlockFetcher with the ability to automatically retry fetches which fail due to + * IOExceptions, which we hope are due to transient network conditions. + * + * This fetcher provides stronger guarantees regarding the parent BlockFetchingListener. In + * particular, the listener will be invoked exactly once per blockId, with a success or failure. + */ +public class RetryingBlockFetcher { + + /** + * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any + * remaining blocks. + */ + public static interface BlockFetchStarter { + /** + * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous + * bootstrapping followed by fully asynchronous block fetching. + * The BlockFetcher must eventually invoke the Listener on every input blockId, or else this + * method must throw an exception. + * + * This method should always attempt to get a new TransportClient from the + * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection + * issues. 
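   * A typical implementation -- editor's illustration, echoing the anonymous classes added to
   * NettyBlockTransferService and ExternalShuffleClient in this series -- is simply:
   *   TransportClient client = clientFactory.createClient(host, port);
   *   new OneForOneBlockFetcher(client, blockIds, listener).start(openBlocksMessage);
   * so that every attempt asks the factory for a client rather than reusing a connection that may
   * already be broken.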
+ */ + void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException; + } + + /** Shared executor service used for waiting and retrying. */ + private static final ExecutorService executorService = Executors.newCachedThreadPool( + NettyUtils.createThreadFactory("Block Fetch Retry")); + + private final Logger logger = LoggerFactory.getLogger(RetryingBlockFetcher.class); + + /** Used to initiate new Block Fetches on our remaining blocks. */ + private final BlockFetchStarter fetchStarter; + + /** Parent listener which we delegate all successful or permanently failed block fetches to. */ + private final BlockFetchingListener listener; + + /** Max number of times we are allowed to retry. */ + private final int maxRetries; + + /** Milliseconds to wait before each retry. */ + private final int retryWaitTime; + + // NOTE: + // All of our non-final fields are synchronized under 'this' and should only be accessed/mutated + // while inside a synchronized block. + /** Number of times we've attempted to retry so far. */ + private int retryCount = 0; + + /** + * Set of all block ids which have not been fetched successfully or with a non-IO Exception. + * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet, + * input ordering is preserved, so we always request blocks in the same order the user provided. + */ + private final LinkedHashSet outstandingBlocksIds; + + /** + * The BlockFetchingListener that is active with our current BlockFetcher. + * When we start a retry, we immediately replace this with a new Listener, which causes all any + * old Listeners to ignore all further responses. + */ + private RetryingBlockFetchListener currentListener; + + public RetryingBlockFetcher( + TransportConf conf, + BlockFetchStarter fetchStarter, + String[] blockIds, + BlockFetchingListener listener) { + this.fetchStarter = fetchStarter; + this.listener = listener; + this.maxRetries = conf.maxIORetries(); + this.retryWaitTime = conf.ioRetryWaitTime(); + this.outstandingBlocksIds = Sets.newLinkedHashSet(); + Collections.addAll(outstandingBlocksIds, blockIds); + this.currentListener = new RetryingBlockFetchListener(); + } + + /** + * Initiates the fetch of all blocks provided in the constructor, with possible retries in the + * event of transient IOExceptions. + */ + public void start() { + fetchAllOutstanding(); + } + + /** + * Fires off a request to fetch all blocks that have not been fetched successfully or permanently + * failed (i.e., by a non-IOException). + */ + private void fetchAllOutstanding() { + // Start by retrieving our shared state within a synchronized block. + String[] blockIdsToFetch; + int numRetries; + RetryingBlockFetchListener myListener; + synchronized (this) { + blockIdsToFetch = outstandingBlocksIds.toArray(new String[outstandingBlocksIds.size()]); + numRetries = retryCount; + myListener = currentListener; + } + + // Now initiate the fetch on all outstanding blocks, possibly initiating a retry if that fails. + try { + fetchStarter.createAndStart(blockIdsToFetch, myListener); + } catch (Exception e) { + logger.error(String.format("Exception while beginning fetch of %s outstanding blocks %s", + blockIdsToFetch.length, numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e); + + if (shouldRetry(e)) { + initiateRetry(); + } else { + for (String bid : blockIdsToFetch) { + listener.onBlockFetchFailure(bid, e); + } + } + } + } + + /** + * Lightweight method which initiates a retry in a different thread. 
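(Editor's note, not part of this patch: with the defaults added to TransportConf in this series,
that is at most spark.shuffle.io.maxRetries = 3 retries, each preceded by a
spark.shuffle.io.retryWaitMs = 5000 ms wait.)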
The retry will involve + * calling fetchAllOutstanding() after a configured wait time. + */ + private synchronized void initiateRetry() { + retryCount += 1; + currentListener = new RetryingBlockFetchListener(); + + logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms", + retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime); + + executorService.submit(new Runnable() { + @Override + public void run() { + Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS); + fetchAllOutstanding(); + } + }); + } + + /** + * Returns true if we should retry due a block fetch failure. We will retry if and only if + * the exception was an IOException and we haven't retried 'maxRetries' times already. + */ + private synchronized boolean shouldRetry(Throwable e) { + boolean isIOException = e instanceof IOException + || (e.getCause() != null && e.getCause() instanceof IOException); + boolean hasRemainingRetries = retryCount < maxRetries; + return isIOException && hasRemainingRetries; + } + + /** + * Our RetryListener intercepts block fetch responses and forwards them to our parent listener. + * Note that in the event of a retry, we will immediately replace the 'currentListener' field, + * indicating that any responses from non-current Listeners should be ignored. + */ + private class RetryingBlockFetchListener implements BlockFetchingListener { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + // We will only forward this success message to our parent listener if this block request is + // outstanding and we are still the active listener. + boolean shouldForwardSuccess = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + outstandingBlocksIds.remove(blockId); + shouldForwardSuccess = true; + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. + if (shouldForwardSuccess) { + listener.onBlockFetchSuccess(blockId, data); + } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + // We will only forward this failure to our parent listener if this block request is + // outstanding, we are still the active listener, AND we cannot retry the fetch. + boolean shouldForwardFailure = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + if (shouldRetry(exception)) { + initiateRetry(); + } else { + logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)", + blockId, retryCount), exception); + outstandingBlocksIds.remove(blockId); + shouldForwardFailure = true; + } + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. 
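        // Editor's note (summary, not part of this patch): a failure therefore has one of three
        // outcomes -- it triggers a retry (IOException with retries remaining), it is forwarded
        // below as a permanent failure (non-IOException or retries exhausted), or it is dropped
        // because a newer listener has already superseded this one.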
+ if (shouldForwardFailure) { + listener.onBlockFetchFailure(blockId, exception); + } + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 84781207861e..d25283e46ef9 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -93,7 +93,7 @@ public void afterEach() { } @Test - public void testGoodClient() { + public void testGoodClient() throws IOException { clientFactory = context.createClientFactory( Lists.newArrayList( new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); @@ -119,7 +119,7 @@ public void testBadClient() { } @Test - public void testNoSaslClient() { + public void testNoSaslClient() throws IOException { clientFactory = context.createClientFactory( Lists.newArrayList()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 71e017b9e4e7..06294fef1962 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -259,14 +259,20 @@ public void testFetchUnregisteredExecutor() throws Exception { @Test public void testFetchNoServer() throws Exception { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-0", - new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + System.setProperty("spark.shuffle.io.maxRetries", "0"); + try { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } finally { + System.clearProperty("spark.shuffle.io.maxRetries"); + } } - private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { + private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) + throws IOException { ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 4c18fcdfbcd8..848c88f743d5 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -17,6 +17,8 @@ package org.apache.spark.network.shuffle; +import java.io.IOException; + import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -54,7 +56,7 @@ public void afterEach() { } @Test - public void testValid() { + public void testValid() throws IOException { validate("my-app-id", 
"secret"); } @@ -77,7 +79,7 @@ public void testBadSecret() { } /** Creates an ExternalShuffleClient and attempts to register with the server. */ - private void validate(String appId, String secretKey) { + private void validate(String appId, String secretKey) throws IOException { ExternalShuffleClient client = new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); client.init(appId); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java new file mode 100644 index 000000000000..0191fe529e1b --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.LinkedHashSet; +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.Stubber; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter; + +/** + * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to + * fetch the lost blocks. + */ +public class RetryingBlockFetcherSuite { + + ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); + ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); + + @Before + public void beforeEach() { + System.setProperty("spark.shuffle.io.maxRetries", "2"); + System.setProperty("spark.shuffle.io.retryWaitMs", "0"); + } + + @After + public void afterEach() { + System.clearProperty("spark.shuffle.io.maxRetries"); + System.clearProperty("spark.shuffle.io.retryWaitMs"); + } + + @Test + public void testNoFailures() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // Immediately return both blocks successfully. 
+ ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchSuccess("b0", block0); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testUnrecoverableFailure() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0 throws a non-IOException error, so it will be failed without retry. + ImmutableMap.builder() + .put("b0", new RuntimeException("Ouch!")) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any()); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnFirst() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b0 fails, we will retry both. + ImmutableMap.builder() + .put("b0", new IOException("Connection failed or something")) + .put("b1", block1) + .build(), + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnSecond() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b1 fails, we will not retry b0. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException("Connection failed or something")) + .build(), + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testTwoIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 returns successfully within 2 retries. + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testThreeIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. 
+ ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 errors again, but this was the last retry + ImmutableMap.builder() + .put("b1", new IOException()) + .build(), + // This is not reached -- b1 has failed. + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verifyNoMoreInteractions(listener); + } + + @Test + public void testRetryAndUnrecoverable() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, subsequent messages will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new RuntimeException()) + .put("b2", block2) + .build(), + // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new RuntimeException()) + .put("b2", new IOException()) + .build(), + // b2 succeeds in its last retry. + ImmutableMap.builder() + .put("b2", block2) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2); + verifyNoMoreInteractions(listener); + } + + /** + * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. + * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction + * means "respond to the next block fetch request with these Successful buffers and these Failure + * exceptions". We verify that the expected block ids are exactly the ones requested. + * + * If multiple interactions are supplied, they will be used in order. This is useful for encoding + * retries -- the first interaction may include an IOException, which causes a retry of some + * subset of the original blocks in a second interaction. + */ + @SuppressWarnings("unchecked") + private void performInteractions(final Map[] interactions, BlockFetchingListener listener) + throws IOException { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); + + Stubber stub = null; + + // Contains all blockIds that are referenced across all interactions. + final LinkedHashSet blockIds = Sets.newLinkedHashSet(); + + for (final Map interaction : interactions) { + blockIds.addAll(interaction.keySet()); + + Answer answer = new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + try { + // Verify that the RetryingBlockFetcher requested the expected blocks. + String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; + String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); + assertArrayEquals(desiredBlockIds, requestedBlockIds); + + // Now actually invoke the success/failure callbacks on each block. 
+ BlockFetchingListener retryListener = + (BlockFetchingListener) invocationOnMock.getArguments()[1]; + for (Map.Entry block : interaction.entrySet()) { + String blockId = block.getKey(); + Object blockValue = block.getValue(); + + if (blockValue instanceof ManagedBuffer) { + retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); + } else if (blockValue instanceof Exception) { + retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); + } else { + fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); + } + } + return null; + } catch (Throwable e) { + e.printStackTrace(); + throw e; + } + } + }; + + // This is either the first stub, or should be chained behind the prior ones. + if (stub == null) { + stub = doAnswer(answer); + } else { + stub.doAnswer(answer); + } + } + + assert stub != null; + stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject()); + String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); + new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); + } +} From f92e6d74910b41c5dc43285cb122b908a97e82c6 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 19:54:32 -0800 Subject: [PATCH 044/652] [SPARK-4236] Cleanup removed applications' files in shuffle service This relies on a hook from whoever is hosting the shuffle service to invoke removeApplication() when the application is completed. Once invoked, we will clean up all the executors' shuffle directories we know about. Author: Aaron Davidson Closes #3126 from aarondav/cleanup and squashes the following commits: 33a64a9 [Aaron Davidson] Missing brace e6e428f [Aaron Davidson] Address comments 16a0d27 [Aaron Davidson] Cleanup e4df3e7 [Aaron Davidson] [SPARK-4236] Cleanup removed applications' files in shuffle service (cherry picked from commit 48a19a6dba896f7d0b637f84e114b7efbb814e51) Signed-off-by: Andrew Or --- .../scala/org/apache/spark/util/Utils.scala | 1 + .../spark/ExternalShuffleServiceSuite.scala | 5 +- .../apache/spark/network/util/JavaUtils.java | 59 ++++++++ .../shuffle/ExternalShuffleBlockHandler.java | 10 +- .../shuffle/ExternalShuffleBlockManager.java | 118 +++++++++++++-- .../shuffle/ExternalShuffleCleanupSuite.java | 142 ++++++++++++++++++ .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/TestShuffleDataContext.java | 4 +- 8 files changed, 319 insertions(+), 22 deletions(-) create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7caf6bcf94ef..2cbd38d72caa 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -755,6 +755,7 @@ private[spark] object Utils extends Logging { /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. 
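   * (Editor's note, not part of this patch: a Java port of this method is added below in
   * org.apache.spark.network.util.JavaUtils#deleteRecursively, so the standalone network modules
   * can reuse the same recursive-delete semantics.)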
*/ def deleteRecursively(file: File) { if (file != null) { diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 792b9cd8b6ff..6608ed1e57b3 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -63,8 +63,9 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { rdd.count() rdd.count() - // Invalidate the registered executors, disallowing access to their shuffle blocks. - rpcHandler.clearRegisteredExecutors() + // Invalidate the registered executors, disallowing access to their shuffle blocks (without + // deleting the actual shuffle files, so we could access them without the shuffle service). + rpcHandler.applicationRemoved(sc.conf.getAppId, false /* cleanupLocalDirs */) // Now Spark will receive FetchFailed, and not retry the stage due to "spark.test.noStageRetry" // being set. diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 2856d1c8c933..75c4a3981a24 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -22,16 +22,22 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.Closeable; +import java.io.File; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import com.google.common.base.Preconditions; import com.google.common.io.Closeables; import com.google.common.base.Charsets; import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +/** + * General utilities available in the network package. Many of these are sourced from Spark's + * own Utils, just accessible within this package. + */ public class JavaUtils { private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); @@ -93,4 +99,57 @@ public static ByteBuffer stringToBytes(String s) { public static String bytesToString(ByteBuffer b) { return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); } + + /* + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. + */ + public static void deleteRecursively(File file) throws IOException { + if (file == null) { return; } + + if (file.isDirectory() && !isSymlink(file)) { + IOException savedIOException = null; + for (File child : listFilesSafely(file)) { + try { + deleteRecursively(child); + } catch (IOException e) { + // In case of multiple exceptions, only last one will be thrown + savedIOException = e; + } + } + if (savedIOException != null) { + throw savedIOException; + } + } + + boolean deleted = file.delete(); + // Delete can also fail if the file simply did not exist. 
+ if (!deleted && file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath()); + } + } + + private static File[] listFilesSafely(File file) throws IOException { + if (file.exists()) { + File[] files = file.listFiles(); + if (files == null) { + throw new IOException("Failed to list files for dir: " + file); + } + return files; + } else { + return new File[0]; + } + } + + private static boolean isSymlink(File file) throws IOException { + Preconditions.checkNotNull(file); + File fileInCanonicalDir = null; + if (file.getParent() == null) { + fileInCanonicalDir = file; + } else { + fileInCanonicalDir = new File(file.getParentFile().getCanonicalFile(), file.getName()); + } + return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index cd3fea85b19a..75ebf8c7b060 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -94,9 +94,11 @@ public StreamManager getStreamManager() { return streamManager; } - /** For testing, clears all executors registered with "RegisterExecutor". */ - @VisibleForTesting - public void clearRegisteredExecutors() { - blockManager.clearRegisteredExecutors(); + /** + * Removes an application (once it has been terminated), and optionally will clean up any + * local directories associated with the executors of that application in a separate thread. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + blockManager.applicationRemoved(appId, cleanupLocalDirs); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java index 6589889fe1be..98fcfb82aa5d 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -21,9 +21,15 @@ import java.io.File; import java.io.FileInputStream; import java.io.IOException; -import java.util.concurrent.ConcurrentHashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Objects; +import com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,13 +49,22 @@ public class ExternalShuffleBlockManager { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockManager.class); - // Map from "appId-execId" to the executor's configuration. - private final ConcurrentHashMap executors = - new ConcurrentHashMap(); + // Map containing all registered executors' metadata. + private final ConcurrentMap executors; - // Returns an id suitable for a single executor within a single application. - private String getAppExecId(String appId, String execId) { - return appId + "-" + execId; + // Single-threaded Java executor used to perform expensive recursive directory deletion. 
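  // Editor's note -- an illustrative sketch, not part of this patch: the naming TODO in the
  // constructor below could plausibly be addressed by reusing the thread-factory helper added to
  // NettyUtils in this series, e.g. (hypothetical thread name)
  //   Executors.newSingleThreadExecutor(
  //       NettyUtils.createThreadFactory("spark-shuffle-directory-cleaner"))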
+ private final Executor directoryCleaner; + + public ExternalShuffleBlockManager() { + // TODO: Give this thread a name. + this(Executors.newSingleThreadExecutor()); + } + + // Allows tests to have more control over when directories are cleaned up. + @VisibleForTesting + ExternalShuffleBlockManager(Executor directoryCleaner) { + this.executors = Maps.newConcurrentMap(); + this.directoryCleaner = directoryCleaner; } /** Registers a new Executor with all the configuration we need to find its shuffle files. */ @@ -57,7 +72,7 @@ public void registerExecutor( String appId, String execId, ExecutorShuffleInfo executorInfo) { - String fullId = getAppExecId(appId, execId); + AppExecId fullId = new AppExecId(appId, execId); logger.info("Registered executor {} with {}", fullId, executorInfo); executors.put(fullId, executorInfo); } @@ -78,7 +93,7 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { int mapId = Integer.parseInt(blockIdParts[2]); int reduceId = Integer.parseInt(blockIdParts[3]); - ExecutorShuffleInfo executor = executors.get(getAppExecId(appId, execId)); + ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { throw new RuntimeException( String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); @@ -94,6 +109,56 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { } } + /** + * Removes our metadata of all executors registered for the given application, and optionally + * also deletes the local directories associated with the executors of that application in a + * separate thread. + * + * It is not valid to call registerExecutor() for an executor with this appId after invoking + * this method. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + logger.info("Application {} removed, cleanupLocalDirs = {}", appId, cleanupLocalDirs); + Iterator> it = executors.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + AppExecId fullId = entry.getKey(); + final ExecutorShuffleInfo executor = entry.getValue(); + + // Only touch executors associated with the appId that was removed. + if (appId.equals(fullId.appId)) { + it.remove(); + + if (cleanupLocalDirs) { + logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(new Runnable() { + @Override + public void run() { + deleteExecutorDirs(executor.localDirs); + } + }); + } + } + } + } + + /** + * Synchronously deletes each directory one at a time. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteExecutorDirs(String[] dirs) { + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir)); + logger.debug("Successfully cleaned up directory: " + localDir); + } catch (Exception e) { + logger.error("Failed to delete directory: " + localDir, e); + } + } + } + /** * Hash-based shuffle data is simply stored as one file per block. * This logic is from FileShuffleBlockManager. @@ -146,9 +211,36 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) return new File(new File(localDir, String.format("%02x", subDirId)), filename); } - /** For testing, clears all registered executors. */ - @VisibleForTesting - void clearRegisteredExecutors() { - executors.clear(); + /** Simply encodes an executor's full ID, which is appId + execId. 
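   * For example (editor's illustration, not part of this patch), executor "exec-2" of application
   * "app-20141107" is keyed as AppExecId("app-20141107", "exec-2"), replacing the old string key
   * "app-20141107-exec-2".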
*/ + private static class AppExecId { + final String appId; + final String execId; + + private AppExecId(String appId, String execId) { + this.appId = appId; + this.execId = execId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + AppExecId appExecId = (AppExecId) o; + return Objects.equal(appId, appExecId.appId) && Objects.equal(execId, appExecId.execId); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .toString(); + } } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java new file mode 100644 index 000000000000..c8ece3bc53ac --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class ExternalShuffleCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + + @Test + public void noCleanupAndCleanup() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", false /* cleanup */); + + assertStillThere(dataContext); + + manager.registerExecutor("app", "exec1", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true /* cleanup */); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutor() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + final AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. 
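    // Editor's note (not part of this patch): since this executor swallows the Runnable, the
    // directories are expected to survive applicationRemoved() and are only removed by the manual
    // dataContext.cleanup() below, demonstrating that deletion is delegated entirely to the
    // injected Executor.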
+ Executor noThreadExecutor = new Executor() { + @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } + }; + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + + dataContext.cleanup(); + assertCleanedUp(dataContext); + } + + @Test + public void cleanupMultipleExecutors() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRemovedApp() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + + manager.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); + + manager.applicationRemoved("app-nonexistent", true); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-0", true); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + private void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertFalse(localDir + " wasn't cleaned up", new File(localDir).exists()); + } + } + + private TestShuffleDataContext createSomeData() throws IOException { + Random rand = new Random(123); + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + + dataContext.create(); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), + new byte[][] { "ABC".getBytes(), "DEF".getBytes() } ); + dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, + new byte[][] { "GHI".getBytes(), "JKLMNOPQRSTUVWXYZ".getBytes() } ); + return dataContext; + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 06294fef1962..3bea5b0f253c 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -105,7 +105,7 @@ public static void afterAll() { @After public void 
afterEach() { - handler.clearRegisteredExecutors(); + handler.applicationRemoved(APP_ID, false /* cleanupLocalDirs */); } class FetchResult { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 442b75646744..337b5c7bdb5d 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -30,8 +30,8 @@ * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. */ public class TestShuffleDataContext { - private final String[] localDirs; - private final int subDirsPerLocalDir; + public final String[] localDirs; + public final int subDirsPerLocalDir; public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) { this.localDirs = new String[numLocalDirs]; From 7f86c350c946ac0c44e5e70acc8b7e51bace90a4 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 6 Nov 2014 21:52:12 -0800 Subject: [PATCH 045/652] [SPARK-4204][Core][WebUI] Change Utils.exceptionString to contain the inner exceptions and make the error information in Web UI more friendly This PR fixed `Utils.exceptionString` to output the full exception information. However, the stack trace may become very huge, so I also updated the Web UI to collapse the error information by default (display the first line and clicking `+detail` will display the full info). Here are the screenshots: Stages: ![stages](https://cloud.githubusercontent.com/assets/1000778/4882441/66d8cc68-6356-11e4-8346-6318677d9470.png) Details for one stage: ![stage](https://cloud.githubusercontent.com/assets/1000778/4882513/1311043c-6357-11e4-8804-ca14240a9145.png) The full information in the gray text field is: ```Java org.apache.spark.shuffle.FetchFailedException: Connection reset by peer at org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher$.org$apache$spark$shuffle$hash$BlockStoreShuffleFetcher$$unpackBlock$1(BlockStoreShuffleFetcher.scala:67) at org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher$$anonfun$3.apply(BlockStoreShuffleFetcher.scala:83) at org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher$$anonfun$3.apply(BlockStoreShuffleFetcher.scala:83) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371) at org.apache.spark.util.CompletionIterator.hasNext(CompletionIterator.scala:30) at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:39) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327) at org.apache.spark.util.collection.ExternalAppendOnlyMap.insertAll(ExternalAppendOnlyMap.scala:129) at org.apache.spark.rdd.CoGroupedRDD$$anonfun$compute$5.apply(CoGroupedRDD.scala:160) at org.apache.spark.rdd.CoGroupedRDD$$anonfun$compute$5.apply(CoGroupedRDD.scala:159) at scala.collection.TraversableLike$WithFilter$$anonfun$foreach$1.apply(TraversableLike.scala:772) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47) at scala.collection.TraversableLike$WithFilter.foreach(TraversableLike.scala:771) at org.apache.spark.rdd.CoGroupedRDD.compute(CoGroupedRDD.scala:159) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:263) at org.apache.spark.rdd.RDD.iterator(RDD.scala:230) at 
org.apache.spark.rdd.MappedValuesRDD.compute(MappedValuesRDD.scala:31) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:263) at org.apache.spark.rdd.RDD.iterator(RDD.scala:230) at org.apache.spark.rdd.FlatMappedValuesRDD.compute(FlatMappedValuesRDD.scala:31) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:263) at org.apache.spark.rdd.RDD.iterator(RDD.scala:230) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:61) at org.apache.spark.scheduler.Task.run(Task.scala:56) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:189) at java.util.concurrent.ThreadPoolExecutor$Worker.runTask(ThreadPoolExecutor.java:886) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:908) at java.lang.Thread.run(Thread.java:662) Caused by: java.io.IOException: Connection reset by peer at sun.nio.ch.FileDispatcher.read0(Native Method) at sun.nio.ch.SocketDispatcher.read(SocketDispatcher.java:21) at sun.nio.ch.IOUtil.readIntoNativeBuffer(IOUtil.java:198) at sun.nio.ch.IOUtil.read(IOUtil.java:166) at sun.nio.ch.SocketChannelImpl.read(SocketChannelImpl.java:245) at io.netty.buffer.PooledUnsafeDirectByteBuf.setBytes(PooledUnsafeDirectByteBuf.java:311) at io.netty.buffer.AbstractByteBuf.writeBytes(AbstractByteBuf.java:881) at io.netty.channel.socket.nio.NioSocketChannel.doReadBytes(NioSocketChannel.java:225) at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:119) at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:511) at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:468) at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:382) at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:354) at io.netty.util.concurrent.SingleThreadEventExecutor$2.run(SingleThreadEventExecutor.java:116) ... 
1 more ``` /cc aarondav Author: zsxwing Closes #3073 from zsxwing/SPARK-4204 and squashes the following commits: 176d1e3 [zsxwing] Add comments to explain the stack trace difference ca509d3 [zsxwing] Add fullStackTrace to the constructor of ExceptionFailure a07057b [zsxwing] Core style fix dfb0032 [zsxwing] Backward compatibility for old history server 1e50f71 [zsxwing] Update as per review and increase the max height of the stack trace details 94f2566 [zsxwing] Change Utils.exceptionString to contain the inner exceptions and make the error information in Web UI more friendly (cherry picked from commit 3abdb1b24aa48f21e7eed1232c01d3933873688c) Signed-off-by: Andrew Or --- .../org/apache/spark/ui/static/webui.css | 14 ++++++++ .../org/apache/spark/TaskEndReason.scala | 35 ++++++++++++++++++- .../org/apache/spark/executor/Executor.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 4 +-- .../spark/shuffle/FetchFailedException.scala | 17 +++++++-- .../hash/BlockStoreShuffleFetcher.scala | 5 ++- .../org/apache/spark/ui/jobs/StagePage.scala | 32 +++++++++++++++-- .../org/apache/spark/ui/jobs/StageTable.scala | 28 +++++++++++++-- .../org/apache/spark/util/JsonProtocol.scala | 5 ++- .../scala/org/apache/spark/util/Utils.scala | 24 ++++++------- .../ui/jobs/JobProgressListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 10 +++++- 12 files changed, 148 insertions(+), 30 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index a2220e761ac9..db57712c8350 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -120,6 +120,20 @@ pre { border: none; } +.stacktrace-details { + max-height: 300px; + overflow-y: auto; + margin: 0; + transition: max-height 0.5s ease-out, padding 0.5s ease-out; +} + +.stacktrace-details.collapsed { + max-height: 0; + padding-top: 0; + padding-bottom: 0; + border: none; +} + span.expand-additional-metrics { cursor: pointer; } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index f45b463fb6f6..af5fd8e0ac00 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -83,15 +83,48 @@ case class FetchFailed( * :: DeveloperApi :: * Task failed due to a runtime exception. This is the most common failure case and also captures * user program exceptions. + * + * `stackTrace` contains the stack trace of the exception itself. It still exists for backward + * compatibility. It's better to use `this(e: Throwable, metrics: Option[TaskMetrics])` to + * create `ExceptionFailure` as it will handle the backward compatibility properly. 
+ * + * `fullStackTrace` is a better representation of the stack trace because it contains the whole + * stack trace including the exception and its causes */ @DeveloperApi case class ExceptionFailure( className: String, description: String, stackTrace: Array[StackTraceElement], + fullStackTrace: String, metrics: Option[TaskMetrics]) extends TaskFailedReason { - override def toErrorString: String = Utils.exceptionString(className, description, stackTrace) + + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + } + + override def toErrorString: String = + if (fullStackTrace == null) { + // fullStackTrace is added in 1.2.0 + // If fullStackTrace is null, use the old error string for backward compatibility + exceptionString(className, description, stackTrace) + } else { + fullStackTrace + } + + /** + * Return a nice string representation of the exception, including the stack trace. + * Note: It does not include the exception's causes, and is only used for backward compatibility. + */ + private def exceptionString( + className: String, + description: String, + stackTrace: Array[StackTraceElement]): String = { + val desc = if (description == null) "" else description + val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") + s"$className: $desc\n$st" + } } /** diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 96114571d6c7..caf4d76713d4 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -263,7 +263,7 @@ private[spark] class Executor( m.executorRunTime = serviceTime m.jvmGCTime = gcTime - startGCTime } - val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics) + val reason = new ExceptionFailure(t, metrics) execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // Don't forcibly exit unless the exception was inherently fatal, to avoid diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 96114c0423a9..22449517d100 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1063,7 +1063,7 @@ class DAGScheduler( if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some("Fetch failure: " + failureMessage)) + markStageAsFinished(failedStage, Some(failureMessage)) runningStages -= failedStage } @@ -1094,7 +1094,7 @@ class DAGScheduler( handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } - case ExceptionFailure(className, description, stackTrace, metrics) => + case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 0c1b6f4defdb..be184464e0ae 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ 
b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -32,10 +32,21 @@ private[spark] class FetchFailedException( shuffleId: Int, mapId: Int, reduceId: Int, - message: String) - extends Exception(message) { + message: String, + cause: Throwable = null) + extends Exception(message, cause) { + + def this( + bmAddress: BlockManagerId, + shuffleId: Int, + mapId: Int, + reduceId: Int, + cause: Throwable) { + this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) + } - def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) + def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, + Utils.exceptionString(this)) } /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 0d5247f4176d..e3e7434df45b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -25,7 +25,7 @@ import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.util.CompletionIterator private[hash] object BlockStoreShuffleFetcher extends Logging { def fetch[T]( @@ -64,8 +64,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blockId match { case ShuffleBlockId(shufId, mapId, _) => val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, - Utils.exceptionString(e)) + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block", e) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 63ed5fc4949c..250bddbe2f26 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -22,6 +22,8 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.executor.TaskMetrics import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData._ @@ -436,13 +438,37 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {diskBytesSpilledReadable} }} - - {errorMessage.map { e =>
    {e}
    }.getOrElse("")} - + {errorMessageCell(errorMessage)} } } + private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = { + val error = errorMessage.getOrElse("") + val isMultiline = error.indexOf('\n') >= 0 + // Display the first line by default + val errorSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + error.substring(0, error.indexOf('\n')) + } else { + error + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + {errorSummary}{details} + } + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { val totalExecutionTime = { if (info.gettingResultTime > 0) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 4ee7f08ab47a..3b4866e05956 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -22,6 +22,8 @@ import scala.xml.Text import java.util.Date +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.scheduler.StageInfo import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.util.Utils @@ -195,7 +197,29 @@ private[ui] class FailedStageTable( override protected def stageRow(s: StageInfo): Seq[Node] = { val basicColumns = super.stageRow(s) - val failureReason =
    {s.failureReason.getOrElse("")}
    - basicColumns ++ failureReason + val failureReason = s.failureReason.getOrElse("") + val isMultiline = failureReason.indexOf('\n') >= 0 + // Display the first line by default + val failureReasonSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + failureReason.substring(0, failureReason.indexOf('\n')) + } else { + failureReason + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + val failureReasonHtml = {failureReasonSummary}{details} + basicColumns ++ failureReasonHtml } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index f7ae1f7f334d..f15d0c856663 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -287,6 +287,7 @@ private[spark] object JsonProtocol { ("Class Name" -> exceptionFailure.className) ~ ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ + ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) case ExecutorLostFailure(executorId) => ("Executor ID" -> executorId) @@ -637,8 +638,10 @@ private[spark] object JsonProtocol { val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") + val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). + map(_.extract[String]).orNull val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - new ExceptionFailure(className, description, stackTrace, metrics) + ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `executorLostFailure` => diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2cbd38d72caa..a14d6125484f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1599,19 +1599,19 @@ private[spark] object Utils extends Logging { .orNull } - /** Return a nice string representation of the exception, including the stack trace. */ + /** + * Return a nice string representation of the exception. It will call "printStackTrace" to + * recursively generate the stack trace including the exception and its causes. + */ def exceptionString(e: Throwable): String = { - if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace) - } - - /** Return a nice string representation of the exception, including the stack trace. */ - def exceptionString( - className: String, - description: String, - stackTrace: Array[StackTraceElement]): String = { - val desc = if (description == null) "" else description - val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") - s"$className: $desc\n$st" + if (e == null) { + "" + } else { + // Use e.printStackTrace here because e.getStackTrace doesn't include the cause + val stringWriter = new StringWriter() + e.printStackTrace(new PrintWriter(stringWriter)) + stringWriter.toString + } } /** Return a thread dump of all threads' stacktraces. 
Used to capture dumps for the web UI */ diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 2efbae689771..2608ad4b32e1 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -116,7 +116,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - new ExceptionFailure("Exception", "description", null, None), + ExceptionFailure("Exception", "description", null, null, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0"), diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index aec1e409db95..39e69851e7e3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -109,7 +109,7 @@ class JsonProtocolSuite extends FunSuite { // TaskEndReason val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, "Some exception") - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, None) + val exceptionFailure = new ExceptionFailure(exception, None) testTaskEndReason(Success) testTaskEndReason(Resubmitted) testTaskEndReason(fetchFailed) @@ -127,6 +127,13 @@ class JsonProtocolSuite extends FunSuite { testBlockId(StreamBlockId(1, 2L)) } + test("ExceptionFailure backward compatibility") { + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) + val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) + .removeField({ _._1 == "Full Stack Trace" }) + assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + test("StageInfo backward compatibility") { val info = makeStageInfo(1, 2, 3, 4L, 5L) val newJson = JsonProtocol.stageInfoToJson(info) @@ -422,6 +429,7 @@ class JsonProtocolSuite extends FunSuite { assert(r1.className === r2.className) assert(r1.description === r2.description) assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals) + assert(r1.fullStackTrace === r2.fullStackTrace) assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => From d6262fa05b9b7ffde00e6659810a3436e53df6b8 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 7 Nov 2014 09:42:21 -0800 Subject: [PATCH 046/652] [SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages This PR eliminates the network package's usage of the Java serializer and replaces it with Encodable, which is a lightweight binary protocol. Each message is preceded by a type id, which will allow us to change messages (by only adding new ones), or to change the format entirely by switching to a special id (such as -1). This protocol has the advantage over Java serialization that we can guarantee that messages will remain compatible across compiled versions and JVMs, though it does not provide a clean way to do schema migration. In the future, it may be good to use a more heavy-weight serialization format like protobuf, thrift, or avro, but these all add several dependencies which are unnecessary at the present time.
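As a rough sketch of the type-tagged framing described above, the following minimal example shows the idea: every serialized message starts with a one-byte type id, followed by length-prefixed fields. This is only an illustration that assumes plain `java.nio.ByteBuffer` instead of the Netty `ByteBuf` helpers the patch actually adds, and the `PingMessage` class with its single `payload` field is hypothetical, not part of the patch.
```Java
// Hypothetical sketch of a type-tagged binary message; it is not the patch's
// BlockTransferMessage/Encoders classes, which are built on Netty's ByteBuf.
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class PingMessage {
  // The first byte of every serialized message identifies its type.
  public static final byte TYPE_ID = 0;

  public final String payload;

  public PingMessage(String payload) {
    this.payload = payload;
  }

  /** Writes the type id, then the payload as a length-prefixed UTF-8 string. */
  public byte[] toByteArray() {
    byte[] bytes = payload.getBytes(StandardCharsets.UTF_8);
    ByteBuffer buf = ByteBuffer.allocate(1 + 4 + bytes.length);
    buf.put(TYPE_ID);
    buf.putInt(bytes.length);
    buf.put(bytes);
    return buf.array();
  }

  /** Reads the type id, checks it, then decodes the remaining fields. */
  public static PingMessage fromByteArray(byte[] msg) {
    ByteBuffer buf = ByteBuffer.wrap(msg);
    byte type = buf.get();
    if (type != TYPE_ID) {
      throw new IllegalArgumentException("Unknown message type: " + type);
    }
    byte[] bytes = new byte[buf.getInt()];
    buf.get(bytes);
    return new PingMessage(new String(bytes, StandardCharsets.UTF_8));
  }
}
```
Adding a new message kind under such a scheme means reserving the next type id and adding a matching decoder case, which is what makes it safe to evolve the set of messages without breaking existing readers.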
Additionally this unifies the RPC messages of NettyBlockTransferService and ExternalShuffleClient. Author: Aaron Davidson Closes #3146 from aarondav/free and squashes the following commits: ed1102a [Aaron Davidson] Remove some unused imports b8e2a49 [Aaron Davidson] Add appId to test 538f2a3 [Aaron Davidson] [SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages (cherry picked from commit d4fa04e50d299e9cad349b3781772956453a696b) Signed-off-by: Reynold Xin --- .../spark/network/BlockTransferService.scala | 4 +- .../network/netty/NettyBlockRpcServer.scala | 31 ++--- .../netty/NettyBlockTransferService.scala | 15 ++- .../network/nio/NioBlockTransferService.scala | 1 + .../apache/spark/storage/BlockManager.scala | 5 +- .../NettyBlockTransferSecuritySuite.scala | 4 +- .../network/protocol/ChunkFetchFailure.java | 12 +- .../spark/network/protocol/Encoders.java | 93 ++++++++++++++ .../spark/network/protocol/RpcFailure.java | 12 +- .../spark/network/protocol/RpcRequest.java | 9 +- .../spark/network/protocol/RpcResponse.java | 9 +- .../apache/spark/network/util/JavaUtils.java | 27 ----- .../spark/network/sasl/SaslMessage.java | 24 ++-- .../shuffle/ExternalShuffleBlockHandler.java | 21 ++-- .../shuffle/ExternalShuffleBlockManager.java | 1 + .../shuffle/ExternalShuffleClient.java | 12 +- .../shuffle/ExternalShuffleMessages.java | 106 ---------------- .../shuffle/OneForOneBlockFetcher.java | 17 ++- .../protocol/BlockTransferMessage.java | 76 ++++++++++++ .../{ => protocol}/ExecutorShuffleInfo.java | 36 +++++- .../network/shuffle/protocol/OpenBlocks.java | 87 ++++++++++++++ .../shuffle/protocol/RegisterExecutor.java | 91 ++++++++++++++ .../StreamHandle.java} | 34 ++++-- .../network/shuffle/protocol/UploadBlock.java | 113 ++++++++++++++++++ ...e.java => BlockTransferMessagesSuite.java} | 33 ++--- .../ExternalShuffleBlockHandlerSuite.java | 29 ++--- .../ExternalShuffleIntegrationSuite.java | 1 + .../shuffle/ExternalShuffleSecuritySuite.java | 1 + .../shuffle/OneForOneBlockFetcherSuite.java | 18 +-- .../shuffle/TestShuffleDataContext.java | 2 + 30 files changed, 640 insertions(+), 284 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java rename network/shuffle/src/main/java/org/apache/spark/network/shuffle/{ => protocol}/ExecutorShuffleInfo.java (68%) create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java rename network/shuffle/src/main/java/org/apache/spark/network/shuffle/{ShuffleStreamHandle.java => protocol/StreamHandle.java} (65%) create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java rename network/shuffle/src/test/java/org/apache/spark/network/shuffle/{ShuffleMessagesSuite.java => BlockTransferMessagesSuite.java} (55%) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 210a581db466..dcbda5a8515d 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -73,6 
+73,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] @@ -110,9 +111,10 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo def uploadBlockSync( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Unit = { - Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf) + Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 1950e7bd634e..b089da8596e2 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -26,18 +26,10 @@ import org.apache.spark.network.BlockDataManager import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} -import org.apache.spark.network.shuffle.ShuffleStreamHandle +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, StorageLevel} -object NettyMessages { - /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ - case class OpenBlocks(blockIds: Seq[BlockId]) - - /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ - case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel) -} - /** * Serves requests to open blocks by simply registering one chunk per block requested. * Handles opening and uploading arbitrary BlockManager blocks. @@ -50,28 +42,29 @@ class NettyBlockRpcServer( blockManager: BlockDataManager) extends RpcHandler with Logging { - import NettyMessages._ - private val streamManager = new OneForOneStreamManager() override def receive( client: TransportClient, messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = { - val ser = serializer.newInstance() - val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes)) + val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes) logTrace(s"Received request: $message") message match { - case OpenBlocks(blockIds) => - val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData) + case openBlocks: OpenBlocks => + val blocks: Seq[ManagedBuffer] = + openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) val streamId = streamManager.registerStream(blocks.iterator) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess( - ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array()) + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) - case UploadBlock(blockId, blockData, level) => - blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level) + case uploadBlock: UploadBlock => + // StorageLevel is serialized as bytes using our JavaSerializer. 
+ val level: StorageLevel = + serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) + val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) + blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) responseContext.onSuccess(new Array[Byte](0)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b937ea825f49..f8a7f640689a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -24,10 +24,10 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} -import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -46,6 +46,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ private[this] var clientFactory: TransportClientFactory = _ + private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { @@ -60,6 +61,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage transportContext = new TransportContext(transportConf, rpcHandler) clientFactory = transportContext.createClientFactory(bootstrap.toList) server = transportContext.createServer() + appId = conf.getAppId logInfo("Server created on " + server.getPort) } @@ -74,8 +76,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, blockIds.toArray, listener) - .start(OpenBlocks(blockIds.map(BlockId.apply))) + new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start() } } @@ -101,12 +102,17 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage override def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] = { val result = Promise[Unit]() val client = clientFactory.createClient(hostname, port) + // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded + // using our binary protocol. + val levelBytes = serializer.newInstance().serialize(level).array() + // Convert or copy nio buffer into array in order to serialize it. 
val nioBuffer = blockData.nioByteBuffer() val array = if (nioBuffer.hasArray) { @@ -117,8 +123,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage data } - val ser = serializer.newInstance() - client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(), + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray, new RpcResponseCallback { override def onSuccess(response: Array[Byte]): Unit = { logTrace(s"Successfully uploaded block $blockId") diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index f56d165daba5..b2aec160635c 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -137,6 +137,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa override def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e48d7772d6ee..39434f473a9d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -35,7 +35,8 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService} -import org.apache.spark.network.shuffle.{ExecutorShuffleInfo, ExternalShuffleClient} +import org.apache.spark.network.shuffle.ExternalShuffleClient +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager @@ -939,7 +940,7 @@ private[spark] class BlockManager( data.rewind() logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel) + peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel) logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms" .format(System.currentTimeMillis - onePeerStartTime)) peersReplicatedTo += peer diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 9162ec980166..530f5d6db5a2 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -36,7 +36,9 @@ import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMat class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { test("security default off") { - testConnection(new SparkConf, new SparkConf) match { + val conf = new SparkConf() + .set("spark.app.id", "app-id") + testConnection(conf, conf) match { case Success(_) => // expected case Failure(t) => fail(t) } diff --git 
a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index 152af98ced7c..986957c1509f 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -38,23 +38,19 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { @Override public int encodedLength() { - return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length; + return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString); } @Override public void encode(ByteBuf buf) { streamChunkId.encode(buf); - byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); - buf.writeInt(errorBytes.length); - buf.writeBytes(errorBytes); + Encoders.Strings.encode(buf, errorString); } public static ChunkFetchFailure decode(ByteBuf buf) { StreamChunkId streamChunkId = StreamChunkId.decode(buf); - int numErrorStringBytes = buf.readInt(); - byte[] errorBytes = new byte[numErrorStringBytes]; - buf.readBytes(errorBytes); - return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8)); + String errorString = Encoders.Strings.decode(buf); + return new ChunkFetchFailure(streamChunkId, errorString); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java new file mode 100644 index 000000000000..873c69425094 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + + +import com.google.common.base.Charsets; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +/** Provides a canonical set of Encoders for simple types. */ +public class Encoders { + + /** Strings are encoded with their length followed by UTF-8 bytes. */ + public static class Strings { + public static int encodedLength(String s) { + return 4 + s.getBytes(Charsets.UTF_8).length; + } + + public static void encode(ByteBuf buf, String s) { + byte[] bytes = s.getBytes(Charsets.UTF_8); + buf.writeInt(bytes.length); + buf.writeBytes(bytes); + } + + public static String decode(ByteBuf buf) { + int length = buf.readInt(); + byte[] bytes = new byte[length]; + buf.readBytes(bytes); + return new String(bytes, Charsets.UTF_8); + } + } + + /** Byte arrays are encoded with their length followed by bytes. 
*/ + public static class ByteArrays { + public static int encodedLength(byte[] arr) { + return 4 + arr.length; + } + + public static void encode(ByteBuf buf, byte[] arr) { + buf.writeInt(arr.length); + buf.writeBytes(arr); + } + + public static byte[] decode(ByteBuf buf) { + int length = buf.readInt(); + byte[] bytes = new byte[length]; + buf.readBytes(bytes); + return bytes; + } + } + + /** String arrays are encoded with the number of strings followed by per-String encoding. */ + public static class StringArrays { + public static int encodedLength(String[] strings) { + int totalLength = 4; + for (String s : strings) { + totalLength += Strings.encodedLength(s); + } + return totalLength; + } + + public static void encode(ByteBuf buf, String[] strings) { + buf.writeInt(strings.length); + for (String s : strings) { + Strings.encode(buf, s); + } + } + + public static String[] decode(ByteBuf buf) { + int numStrings = buf.readInt(); + String[] strings = new String[numStrings]; + for (int i = 0; i < strings.length; i ++) { + strings[i] = Strings.decode(buf); + } + return strings; + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index e239d4ffbd29..ebd764eb5eb5 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -36,23 +36,19 @@ public RpcFailure(long requestId, String errorString) { @Override public int encodedLength() { - return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length; + return 8 + Encoders.Strings.encodedLength(errorString); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); - buf.writeInt(errorBytes.length); - buf.writeBytes(errorBytes); + Encoders.Strings.encode(buf, errorString); } public static RpcFailure decode(ByteBuf buf) { long requestId = buf.readLong(); - int numErrorStringBytes = buf.readInt(); - byte[] errorBytes = new byte[numErrorStringBytes]; - buf.readBytes(errorBytes); - return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8)); + String errorString = Encoders.Strings.decode(buf); + return new RpcFailure(requestId, errorString); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 099e934ae018..cdee0b0e0316 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -44,21 +44,18 @@ public RpcRequest(long requestId, byte[] message) { @Override public int encodedLength() { - return 8 + 4 + message.length; + return 8 + Encoders.ByteArrays.encodedLength(message); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - buf.writeInt(message.length); - buf.writeBytes(message); + Encoders.ByteArrays.encode(buf, message); } public static RpcRequest decode(ByteBuf buf) { long requestId = buf.readLong(); - int messageLen = buf.readInt(); - byte[] message = new byte[messageLen]; - buf.readBytes(message); + byte[] message = Encoders.ByteArrays.decode(buf); return new RpcRequest(requestId, message); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java 
b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index ed479478325b..0a62e09a8115 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -36,20 +36,17 @@ public RpcResponse(long requestId, byte[] response) { public Type type() { return Type.RpcResponse; } @Override - public int encodedLength() { return 8 + 4 + response.length; } + public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - buf.writeInt(response.length); - buf.writeBytes(response); + Encoders.ByteArrays.encode(buf, response); } public static RpcResponse decode(ByteBuf buf) { long requestId = buf.readLong(); - int responseLen = buf.readInt(); - byte[] response = new byte[responseLen]; - buf.readBytes(response); + byte[] response = Encoders.ByteArrays.decode(buf); return new RpcResponse(requestId, response); } diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 75c4a3981a24..009dbcf01323 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -50,33 +50,6 @@ public static void closeQuietly(Closeable closeable) { } } - // TODO: Make this configurable, do not use Java serialization! - public static T deserialize(byte[] bytes) { - try { - ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes)); - Object out = is.readObject(); - is.close(); - return (T) out; - } catch (ClassNotFoundException e) { - throw new RuntimeException("Could not deserialize object", e); - } catch (IOException e) { - throw new RuntimeException("Could not deserialize object", e); - } - } - - // TODO: Make this configurable, do not use Java serialization! - public static byte[] serialize(Object object) { - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ObjectOutputStream os = new ObjectOutputStream(baos); - os.writeObject(object); - os.close(); - return baos.toByteArray(); - } catch (IOException e) { - throw new RuntimeException("Could not serialize object", e); - } - } - /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). 
*/ public static int nonNegativeHash(Object obj) { if (obj == null) { return 0; } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 599cc6428c90..cad76ab7aa54 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -17,10 +17,10 @@ package org.apache.spark.network.sasl; -import com.google.common.base.Charsets; import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged @@ -42,18 +42,14 @@ public SaslMessage(String appId, byte[] payload) { @Override public int encodedLength() { - // tag + appIdLength + appId + payloadLength + payload - return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length; + return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); } @Override public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); - byte[] idBytes = appId.getBytes(Charsets.UTF_8); - buf.writeInt(idBytes.length); - buf.writeBytes(idBytes); - buf.writeInt(payload.length); - buf.writeBytes(payload); + Encoders.Strings.encode(buf, appId); + Encoders.ByteArrays.encode(buf, payload); } public static SaslMessage decode(ByteBuf buf) { @@ -62,14 +58,8 @@ public static SaslMessage decode(ByteBuf buf) { + " (maybe your client does not have SASL enabled?)"); } - int idLength = buf.readInt(); - byte[] idBytes = new byte[idLength]; - buf.readBytes(idBytes); - - int payloadLength = buf.readInt(); - byte[] payload = new byte[payloadLength]; - buf.readBytes(payload); - - return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload); + String appId = Encoders.Strings.decode(buf); + byte[] payload = Encoders.ByteArrays.decode(buf); + return new SaslMessage(appId, payload); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 75ebf8c7b060..a6db4b2abd6c 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -24,15 +24,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; - import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; /** * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. 
@@ -62,12 +63,10 @@ public ExternalShuffleBlockHandler() { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - Object msgObj = JavaUtils.deserialize(message); - - logger.trace("Received message: " + msgObj); + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); - if (msgObj instanceof OpenShuffleBlocks) { - OpenShuffleBlocks msg = (OpenShuffleBlocks) msgObj; + if (msgObj instanceof OpenBlocks) { + OpenBlocks msg = (OpenBlocks) msgObj; List blocks = Lists.newArrayList(); for (String blockId : msg.blockIds) { @@ -75,8 +74,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } long streamId = streamManager.registerStream(blocks.iterator()); logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); - callback.onSuccess(JavaUtils.serialize( - new ShuffleStreamHandle(streamId, msg.blockIds.length))); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); } else if (msgObj instanceof RegisterExecutor) { RegisterExecutor msg = (RegisterExecutor) msgObj; @@ -84,8 +82,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback.onSuccess(new byte[0]); } else { - throw new UnsupportedOperationException(String.format( - "Unexpected message: %s (class = %s)", msgObj, msgObj.getClass())); + throw new UnsupportedOperationException("Unexpected message: " + msgObj); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java index 98fcfb82aa5d..ffb7faa3dbdc 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -35,6 +35,7 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.JavaUtils; /** diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 27884b82c8cb..6e8018b723dc 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -31,8 +31,8 @@ import org.apache.spark.network.sasl.SaslClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; -import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.util.TransportConf; /** @@ -91,8 +91,7 @@ public void fetchBlocks( public void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException { TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, blockIds, listener) - .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + new OneForOneBlockFetcher(client, appId, execId, blockIds, 
listener).start(); } }; @@ -128,9 +127,8 @@ public void registerWithShuffleServer( ExecutorShuffleInfo executorInfo) throws IOException { assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); - byte[] registerExecutorMessage = - JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); - client.sendRpcSync(registerExecutorMessage, 5000 /* timeoutMs */); + byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); } @Override diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java deleted file mode 100644 index e79420ed8254..000000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.Serializable; -import java.util.Arrays; - -import com.google.common.base.Objects; - -/** Messages handled by the {@link ExternalShuffleBlockHandler}. */ -public class ExternalShuffleMessages { - - /** Request to read a set of shuffle blocks. Returns [[ShuffleStreamHandle]]. */ - public static class OpenShuffleBlocks implements Serializable { - public final String appId; - public final String execId; - public final String[] blockIds; - - public OpenShuffleBlocks(String appId, String execId, String[] blockIds) { - this.appId = appId; - this.execId = execId; - this.blockIds = blockIds; - } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .add("blockIds", Arrays.toString(blockIds)) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof OpenShuffleBlocks) { - OpenShuffleBlocks o = (OpenShuffleBlocks) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Arrays.equals(blockIds, o.blockIds); - } - return false; - } - } - - /** Initial registration message between an executor and its local shuffle server. 
*/ - public static class RegisterExecutor implements Serializable { - public final String appId; - public final String execId; - public final ExecutorShuffleInfo executorInfo; - - public RegisterExecutor( - String appId, - String execId, - ExecutorShuffleInfo executorInfo) { - this.appId = appId; - this.execId = execId; - this.executorInfo = executorInfo; - } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId, executorInfo); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .add("executorInfo", executorInfo) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof RegisterExecutor) { - RegisterExecutor o = (RegisterExecutor) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Objects.equal(executorInfo, o.executorInfo); - } - return false; - } - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 9e77a1f68c4b..8ed2e0b39ad2 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -26,6 +26,9 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; import org.apache.spark.network.util.JavaUtils; /** @@ -41,17 +44,21 @@ public class OneForOneBlockFetcher { private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); private final TransportClient client; + private final OpenBlocks openMessage; private final String[] blockIds; private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; - private ShuffleStreamHandle streamHandle = null; + private StreamHandle streamHandle = null; public OneForOneBlockFetcher( TransportClient client, + String appId, + String execId, String[] blockIds, BlockFetchingListener listener) { this.client = client; + this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); @@ -76,18 +83,18 @@ public void onFailure(int chunkIndex, Throwable e) { /** * Begins the fetching process, calling the listener with every block fetched. * The given message will be serialized with the Java serializer, and the RPC must return a - * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling. + * {@link StreamHandle}. We will send all fetch requests immediately, without throttling. 
*/ - public void start(Object openBlocksMessage) { + public void start() { if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() { + client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { try { - streamHandle = JavaUtils.deserialize(response); + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); // Immediately request all chunks -- we expect that the total size of the request is diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java new file mode 100644 index 000000000000..b4b13b8a6ef5 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or + * by Spark's NettyBlockTransferService. + * + * At a high level: + * - OpenBlock is handled by both services, but only services shuffle files for the external + * shuffle service. It returns a StreamHandle. + * - UploadBlock is only handled by the NettyBlockTransferService. + * - RegisterExecutor is only handled by the external shuffle service. + */ +public abstract class BlockTransferMessage implements Encodable { + protected abstract Type type(); + + /** Preceding every serialized message is its type, which allows us to deserialize it. */ + public static enum Type { + OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 message types"; + this.id = (byte) id; + } + + public byte id() { return id; } + } + + // NB: Java does not support static methods in interfaces, so we must put this in a static class. + public static class Decoder { + /** Deserializes the 'type' byte followed by the message itself. 
*/ + public static BlockTransferMessage fromByteArray(byte[] msg) { + ByteBuf buf = Unpooled.wrappedBuffer(msg); + byte type = buf.readByte(); + switch (type) { + case 0: return OpenBlocks.decode(buf); + case 1: return UploadBlock.decode(buf); + case 2: return RegisterExecutor.decode(buf); + case 3: return StreamHandle.decode(buf); + default: throw new IllegalArgumentException("Unknown message type: " + type); + } + } + } + + /** Serializes the 'type' byte followed by the message itself. */ + public byte[] toByteArray() { + ByteBuf buf = Unpooled.buffer(encodedLength()); + buf.writeByte(type().id); + encode(buf); + assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); + return buf.array(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java similarity index 68% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java rename to network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index d45e64656a0e..cadc8e8369c6 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -15,21 +15,24 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle; +package org.apache.spark.network.shuffle.protocol; -import java.io.Serializable; import java.util.Arrays; import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; /** Contains all configuration necessary for locating the shuffle files of an executor. */ -public class ExecutorShuffleInfo implements Serializable { +public class ExecutorShuffleInfo implements Encodable { /** The base set of local directories that the executor stores its shuffle files in. */ - final String[] localDirs; + public final String[] localDirs; /** Number of subdirectories created within each localDir. */ - final int subDirsPerLocalDir; + public final int subDirsPerLocalDir; /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. 
*/ - final String shuffleManager; + public final String shuffleManager; public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { this.localDirs = localDirs; @@ -61,4 +64,25 @@ public boolean equals(Object other) { } return false; } + + @Override + public int encodedLength() { + return Encoders.StringArrays.encodedLength(localDirs) + + 4 // int + + Encoders.Strings.encodedLength(shuffleManager); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.StringArrays.encode(buf, localDirs); + buf.writeInt(subDirsPerLocalDir); + Encoders.Strings.encode(buf, shuffleManager); + } + + public static ExecutorShuffleInfo decode(ByteBuf buf) { + String[] localDirs = Encoders.StringArrays.decode(buf); + int subDirsPerLocalDir = buf.readInt(); + String shuffleManager = Encoders.Strings.decode(buf); + return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java new file mode 100644 index 000000000000..60485bace643 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** Request to read a set of blocks. Returns {@link StreamHandle}. 
*/ +public class OpenBlocks extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String[] blockIds; + + public OpenBlocks(String appId, String execId, String[] blockIds) { + this.appId = appId; + this.execId = execId; + this.blockIds = blockIds; + } + + @Override + protected Type type() { return Type.OPEN_BLOCKS; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockIds", Arrays.toString(blockIds)) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof OpenBlocks) { + OpenBlocks o = (OpenBlocks) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Arrays.equals(blockIds, o.blockIds); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.StringArrays.encodedLength(blockIds); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.StringArrays.encode(buf, blockIds); + } + + public static OpenBlocks decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String[] blockIds = Encoders.StringArrays.decode(buf); + return new OpenBlocks(appId, execId, blockIds); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java new file mode 100644 index 000000000000..38acae3b31d6 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** + * Initial registration message between an executor and its local shuffle server. + * Returns nothing (empty bye array). 
+ */ +public class RegisterExecutor extends BlockTransferMessage { + public final String appId; + public final String execId; + public final ExecutorShuffleInfo executorInfo; + + public RegisterExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + this.appId = appId; + this.execId = execId; + this.executorInfo = executorInfo; + } + + @Override + protected Type type() { return Type.REGISTER_EXECUTOR; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, executorInfo); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("executorInfo", executorInfo) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof RegisterExecutor) { + RegisterExecutor o = (RegisterExecutor) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(executorInfo, o.executorInfo); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + executorInfo.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + executorInfo.encode(buf); + } + + public static RegisterExecutor decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + ExecutorShuffleInfo executorShuffleInfo = ExecutorShuffleInfo.decode(buf); + return new RegisterExecutor(appId, execId, executorShuffleInfo); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java similarity index 65% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java rename to network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java index 9c9469122432..21369c8cfb0d 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -15,26 +15,29 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle; +package org.apache.spark.network.shuffle.protocol; import java.io.Serializable; -import java.util.Arrays; import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; /** * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" - * message. This is used by {@link OneForOneBlockFetcher}. + * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}. 
*/ -public class ShuffleStreamHandle implements Serializable { +public class StreamHandle extends BlockTransferMessage { public final long streamId; public final int numChunks; - public ShuffleStreamHandle(long streamId, int numChunks) { + public StreamHandle(long streamId, int numChunks) { this.streamId = streamId; this.numChunks = numChunks; } + @Override + protected Type type() { return Type.STREAM_HANDLE; } + @Override public int hashCode() { return Objects.hashCode(streamId, numChunks); @@ -50,11 +53,28 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof ShuffleStreamHandle) { - ShuffleStreamHandle o = (ShuffleStreamHandle) other; + if (other != null && other instanceof StreamHandle) { + StreamHandle o = (StreamHandle) other; return Objects.equal(streamId, o.streamId) && Objects.equal(numChunks, o.numChunks); } return false; } + + @Override + public int encodedLength() { + return 8 + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(streamId); + buf.writeInt(numChunks); + } + + public static StreamHandle decode(ByteBuf buf) { + long streamId = buf.readLong(); + int numChunks = buf.readInt(); + return new StreamHandle(streamId, numChunks); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java new file mode 100644 index 000000000000..38abe29cc585 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ +public class UploadBlock extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String blockId; + // TODO: StorageLevel is serialized separately in here because StorageLevel is not available in + // this package. We should avoid this hack. + public final byte[] metadata; + public final byte[] blockData; + + /** + * @param metadata Meta-information about block, typically StorageLevel. + * @param blockData The actual block's bytes. 
+ */ + public UploadBlock( + String appId, + String execId, + String blockId, + byte[] metadata, + byte[] blockData) { + this.appId = appId; + this.execId = execId; + this.blockId = blockId; + this.metadata = metadata; + this.blockData = blockData; + } + + @Override + protected Type type() { return Type.UPLOAD_BLOCK; } + + @Override + public int hashCode() { + int objectsHashCode = Objects.hashCode(appId, execId, blockId); + return (objectsHashCode * 41 + Arrays.hashCode(metadata)) * 41 + Arrays.hashCode(blockData); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockId", blockId) + .add("metadata size", metadata.length) + .add("block size", blockData.length) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadBlock) { + UploadBlock o = (UploadBlock) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(blockId, o.blockId) + && Arrays.equals(metadata, o.metadata) + && Arrays.equals(blockData, o.blockData); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.Strings.encodedLength(blockId) + + Encoders.ByteArrays.encodedLength(metadata) + + Encoders.ByteArrays.encodedLength(blockData); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.Strings.encode(buf, blockId); + Encoders.ByteArrays.encode(buf, metadata); + Encoders.ByteArrays.encode(buf, blockData); + } + + public static UploadBlock decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String blockId = Encoders.Strings.decode(buf); + byte[] metadata = Encoders.ByteArrays.decode(buf); + byte[] blockData = Encoders.ByteArrays.decode(buf); + return new UploadBlock(appId, execId, blockId, metadata, blockData); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java similarity index 55% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java rename to network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index ee9482b49cfc..d65de9ca550a 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -21,31 +21,24 @@ import static org.junit.Assert.*; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.*; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; - -public class ShuffleMessagesSuite { +/** Verifies that all BlockTransferMessages can be serialized correctly. 
*/ +public class BlockTransferMessagesSuite { @Test public void serializeOpenShuffleBlocks() { - OpenShuffleBlocks msg = new OpenShuffleBlocks("app-1", "exec-2", - new String[] { "block0", "block1" }); - OpenShuffleBlocks msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); - assertEquals(msg, msg2); + checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); + checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( + new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); + checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 }, + new byte[] { 4, 5, 6, 7} )); + checkSerializeDeserialize(new StreamHandle(12345, 16)); } - @Test - public void serializeRegisterExecutor() { - RegisterExecutor msg = new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( - new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")); - RegisterExecutor msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); - assertEquals(msg, msg2); - } - - @Test - public void serializeShuffleStreamHandle() { - ShuffleStreamHandle msg = new ShuffleStreamHandle(12345, 16); - ShuffleStreamHandle msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); + private void checkSerializeDeserialize(BlockTransferMessage msg) { + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); assertEquals(msg, msg2); + assertEquals(msg.hashCode(), msg2.hashCode()); + assertEquals(msg.toString(), msg2.toString()); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 7939cb4d3269..3f9fe1681cf2 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -24,8 +24,6 @@ import org.junit.Test; import org.mockito.ArgumentCaptor; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; import static org.junit.Assert.*; import static org.mockito.Matchers.any; import static org.mockito.Mockito.*; @@ -36,7 +34,12 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.shuffle.protocol.UploadBlock; public class ExternalShuffleBlockHandlerSuite { TransportClient client = mock(TransportClient.class); @@ -57,8 +60,7 @@ public void testRegisterExecutor() { RpcResponseCallback callback = mock(RpcResponseCallback.class); ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); - byte[] registerMessage = JavaUtils.serialize( - new RegisterExecutor("app0", "exec1", config)); + byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); handler.receive(client, 
registerMessage, callback); verify(blockManager, times(1)).registerExecutor("app0", "exec1", config); @@ -75,9 +77,8 @@ public void testOpenShuffleBlocks() { ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - byte[] openBlocksMessage = JavaUtils.serialize( - new OpenShuffleBlocks("app0", "exec1", new String[] { "b0", "b1" })); - handler.receive(client, openBlocksMessage, callback); + byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); + handler.receive(client, openBlocks, callback); verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0"); verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1"); @@ -85,7 +86,8 @@ public void testOpenShuffleBlocks() { verify(callback, times(1)).onSuccess(response.capture()); verify(callback, never()).onFailure((Throwable) any()); - ShuffleStreamHandle handle = JavaUtils.deserialize(response.getValue()); + StreamHandle handle = + (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); assertEquals(2, handle.numChunks); ArgumentCaptor stream = ArgumentCaptor.forClass(Iterator.class); @@ -100,18 +102,17 @@ public void testOpenShuffleBlocks() { public void testBadMessages() { RpcResponseCallback callback = mock(RpcResponseCallback.class); - byte[] unserializableMessage = new byte[] { 0x12, 0x34, 0x56 }; + byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 }; try { - handler.receive(client, unserializableMessage, callback); + handler.receive(client, unserializableMsg, callback); fail("Should have thrown"); } catch (Exception e) { // pass } - byte[] unexpectedMessage = JavaUtils.serialize( - new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort")); + byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray(); try { - handler.receive(client, unexpectedMessage, callback); + handler.receive(client, unexpectedMsg, callback); fail("Should have thrown"); } catch (UnsupportedOperationException e) { // pass diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 3bea5b0f253c..687bde59fdae 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -42,6 +42,7 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 848c88f743d5..8afceab1d585 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -31,6 +31,7 @@ import 
org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index c18346f6966d..842741e3d354 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -40,7 +40,9 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; public class OneForOneBlockFetcherSuite { @Test @@ -119,17 +121,19 @@ public void testEmptyBlockFetch() { private BlockFetchingListener fetchBlocks(final LinkedHashMap blocks) { TransportClient client = mock(TransportClient.class); BlockFetchingListener listener = mock(BlockFetchingListener.class); - String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); - OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, blockIds, listener); + final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + OneForOneBlockFetcher fetcher = + new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener); // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123 doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - String message = JavaUtils.deserialize((byte[]) invocationOnMock.getArguments()[0]); + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray( + (byte[]) invocationOnMock.getArguments()[0]); RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(123, blocks.size()))); - assertEquals("OpenZeBlocks", message); + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray()); + assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); return null; } }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); @@ -161,7 +165,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { } }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any()); - fetcher.start("OpenZeBlocks"); + fetcher.start(); return listener; } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 337b5c7bdb5d..76639114df5d 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -25,6 +25,8 @@ import com.google.common.io.Files; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; + 
/** * Manages some sort- and hash-based shuffle data, including the creation * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. From e5b8cea7ef219be33df1db77a0921885833a4254 Mon Sep 17 00:00:00 2001 From: wangfei Date: Fri, 7 Nov 2014 11:43:35 -0800 Subject: [PATCH 047/652] [SQL][DOC][Minor] Spark SQL Hive now support dynamic partitioning Author: wangfei Closes #3127 from scwf/patch-9 and squashes the following commits: e39a560 [wangfei] now support dynamic partitioning (cherry picked from commit 636d7bcc96b912f5b5caa91110cd55b55fa38ad8) Signed-off-by: Michael Armbrust --- docs/sql-programming-guide.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e399fecbbc78..ffcce2c58887 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1059,7 +1059,6 @@ in Hive deployments. **Major Hive Features** -* Spark SQL does not currently support inserting to tables using dynamic partitioning. * Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL doesn't support buckets yet. From 2cd8e3e2b00c6191bccfb70743df7a4c9ffd98b2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 7 Nov 2014 11:45:25 -0800 Subject: [PATCH 048/652] [SPARK-4225][SQL] Resorts to SparkContext.version to inspect Spark version This PR resorts to `SparkContext.version` rather than META-INF/MANIFEST.MF in the assembly jar to inspect Spark version. Currently, when built with Maven, the MANIFEST.MF file in the assembly jar is incorrectly replaced by Guava 15.0 MANIFEST.MF, probably because of the assembly/shading tricks. Another related PR is #3103, which tries to fix the MANIFEST issue. Author: Cheng Lian Closes #3105 from liancheng/spark-4225 and squashes the following commits: d9585e1 [Cheng Lian] Resorts to SparkContext.version to inspect Spark version (cherry picked from commit 86e9eaa3f0ec23cb38bce67585adb2d5f484f4ee) Signed-off-by: Michael Armbrust --- .../scala/org/apache/spark/util/Utils.scala | 24 ++++++------------- .../thriftserver/SparkSQLCLIService.scala | 12 ++++------ 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a14d6125484f..6b85c03da533 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,10 +21,8 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.jar.Attributes.Name -import java.util.{Properties, Locale, Random, UUID} -import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} -import java.util.jar.{Manifest => JarManifest} +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} +import java.util.{Locale, Properties, Random, UUID} import scala.collection.JavaConversions._ import scala.collection.Map @@ -38,11 +36,11 @@ import com.google.common.io.{ByteStreams, Files} import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration -import org.apache.log4j.PropertyConfigurator import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} +import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ -import tachyon.client.{TachyonFile,TachyonFS} +import 
tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -352,8 +350,8 @@ private[spark] object Utils extends Logging { * Download a file to target directory. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. * - * If `useCache` is true, first attempts to fetch the file to a local cache that's shared - * across executors running the same application. `useCache` is used mainly for + * If `useCache` is true, first attempts to fetch the file to a local cache that's shared + * across executors running the same application. `useCache` is used mainly for * the executors, and not in local mode. * * Throws SparkException if the target file already exists and has different contents than @@ -400,7 +398,7 @@ private[spark] object Utils extends Logging { } else { doFetchFile(url, targetDir, fileName, conf, securityMgr, hadoopConf) } - + // Decompress the file if it's a .tar or .tar.gz if (fileName.endsWith(".tar.gz") || fileName.endsWith(".tgz")) { logInfo("Untarring " + fileName) @@ -1776,13 +1774,6 @@ private[spark] object Utils extends Logging { s"$libraryPathEnvName=$libraryPath$ampersand" } - lazy val sparkVersion = - SparkContext.jarOfObject(this).map { path => - val manifestUrl = new URL(s"jar:file:$path!/META-INF/MANIFEST.MF") - val manifest = new JarManifest(manifestUrl.openStream()) - manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION) - }.getOrElse("Unknown") - /** * Return the value of a config either through the SparkConf or the Hadoop configuration * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf @@ -1796,7 +1787,6 @@ private[spark] object Utils extends Logging { sparkValue } } - } /** diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index ecfb74473e92..499e077d7294 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -17,18 +17,16 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.jar.Attributes.Name - -import scala.collection.JavaConversions._ - import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException +import scala.collection.JavaConversions._ + import org.apache.commons.logging.Log -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory import org.apache.hive.service.cli._ @@ -50,7 +48,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) addService(sparkSqlSessionManager) var sparkServiceUGI: UserGroupInformation = null - if (ShimLoader.getHadoopShims().isSecurityEnabled()) { + if (ShimLoader.getHadoopShims.isSecurityEnabled) { try { HiveAuthFactory.loginFromKeytab(hiveConf) sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) @@ -68,7 +66,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) getInfoType match { case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL") case 
GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL") - case GetInfoType.CLI_DBMS_VER => new GetInfoValue(Utils.sparkVersion) + case GetInfoType.CLI_DBMS_VER => new GetInfoValue(hiveContext.sparkContext.version) case _ => super.getInfo(sessionHandle, getInfoType) } } From f1f1ae418031957256e7dac896e29d64c81bf1a4 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 7 Nov 2014 11:51:20 -0800 Subject: [PATCH 049/652] [SQL] Support ScalaReflection of schema in different universes Author: Michael Armbrust Closes #3096 from marmbrus/reflectionContext and squashes the following commits: adc221f [Michael Armbrust] Support ScalaReflection of schema in different universes (cherry picked from commit 8154ed7df6c5407e638f465d3bd86b43f36216ef) Signed-off-by: Michael Armbrust --- .../spark/sql/catalyst/ScalaReflection.scala | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9cda373623cb..71034c2c43c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -26,14 +26,26 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.types.decimal.Decimal + /** - * Provides experimental support for generating catalyst schemas for scala objects. + * A default version of ScalaReflection that uses the runtime universe. */ -object ScalaReflection { +object ScalaReflection extends ScalaReflection { + val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe +} + +/** + * Support for generating catalyst schemas for scala objects. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + import universe._ + // The Predef.Map is scala.collection.immutable.Map. // Since the map values can be mutable, we explicitly import scala.collection.Map at here. 
import scala.collection.Map - import scala.reflect.runtime.universe._ case class Schema(dataType: DataType, nullable: Boolean) From 51ef8ab8eca15addc476f47e04ecc578e6e9682c Mon Sep 17 00:00:00 2001 From: Jacky Li Date: Fri, 7 Nov 2014 11:52:08 -0800 Subject: [PATCH 050/652] [SQL] Modify keyword val location according to ordering 'DOUBLE' should be moved before 'ELSE' according to the ordering convension Author: Jacky Li Closes #3080 from jackylk/patch-5 and squashes the following commits: 3c11df7 [Jacky Li] [SQL] Modify keyword val location according to ordering (cherry picked from commit 68609c51ad1ab2def302df3c4a1c0bc1ec6e1075) Signed-off-by: Michael Armbrust --- .../main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 5e613e0f18ba..affef276c2a8 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -55,10 +55,10 @@ class SqlParser extends AbstractSparkSQLParser { protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") + protected val DOUBLE = Keyword("DOUBLE") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") - protected val DOUBLE = Keyword("DOUBLE") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") protected val FROM = Keyword("FROM") From d530c3952131b29fd4d7a3e54496bfe634517af1 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 7 Nov 2014 11:56:40 -0800 Subject: [PATCH 051/652] [SPARK-4213][SQL] ParquetFilters - No support for LT, LTE, GT, GTE operators Following description is quoted from JIRA: When I issue a hql query against a HiveContext where my predicate uses a column of string type with one of LT, LTE, GT, or GTE operator, I get the following error: scala.MatchError: StringType (of class org.apache.spark.sql.catalyst.types.StringType$) Looking at the code in org.apache.spark.sql.parquet.ParquetFilters, StringType is absent from the corresponding functions for creating these filters. To reproduce, in a Hive 0.13.1 shell, I created the following table (at a specified DB): create table sparkbug ( id int, event string ) stored as parquet; Insert some sample data: insert into table sparkbug select 1, '2011-06-18' from limit 1; insert into table sparkbug select 2, '2012-01-01' from limit 1; Launch a spark shell and create a HiveContext to the metastore where the table above is located. import org.apache.spark.sql._ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveContext val hc = new HiveContext(sc) hc.setConf("spark.sql.shuffle.partitions", "10") hc.setConf("spark.sql.hive.convertMetastoreParquet", "true") hc.setConf("spark.sql.parquet.compression.codec", "snappy") import hc._ hc.hql("select * from .sparkbug where event >= '2011-12-01'") A scala.MatchError will appear in the output. 
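To make the failure mode concrete, here is a minimal standalone sketch (the Toy* names and the simplified filter function are made up for illustration; they are not the real catalyst or ParquetFilters code):

~~~
// Toy stand-ins for the catalyst data types -- illustration only.
sealed trait ToyDataType
case object ToyIntegerType extends ToyDataType
case object ToyStringType extends ToyDataType

// Simplified stand-in for the createLessThanFilter-style helpers, which
// pattern match on the literal's data type.
def createToyFilter(dt: ToyDataType): String = dt match {
  case ToyIntegerType => "integer filter"
  // no case for ToyStringType, so the match is not exhaustive
}

createToyFilter(ToyStringType)
// throws scala.MatchError at runtime, just like the query above
~~~

The filter-creation helpers in ParquetFilters pattern match on the literal's data type in the same way, so any data type without a corresponding case fails with scala.MatchError at runtime.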
Author: Kousuke Saruta Closes #3083 from sarutak/SPARK-4213 and squashes the following commits: 4ab6e56 [Kousuke Saruta] WIP b6890c6 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-4213 9a1fae7 [Kousuke Saruta] Fixed ParquetFilters so that compare Strings (cherry picked from commit 14c54f1876fcf91b5c10e80be2df5421c7328557) Signed-off-by: Michael Armbrust --- .../spark/sql/parquet/ParquetFilters.scala | 335 +++++++++++++++++- .../spark/sql/parquet/ParquetQuerySuite.scala | 40 +++ 2 files changed, 364 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 517a5cf0029e..1e67799e8399 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -18,13 +18,15 @@ package org.apache.spark.sql.parquet import java.nio.ByteBuffer +import java.sql.{Date, Timestamp} import org.apache.hadoop.conf.Configuration +import parquet.common.schema.ColumnPath import parquet.filter2.compat.FilterCompat import parquet.filter2.compat.FilterCompat._ -import parquet.filter2.predicate.FilterPredicate -import parquet.filter2.predicate.FilterApi +import parquet.filter2.predicate.Operators.{Column, SupportsLtGt} +import parquet.filter2.predicate.{FilterApi, FilterPredicate} import parquet.filter2.predicate.FilterApi._ import parquet.io.api.Binary import parquet.column.ColumnReader @@ -33,9 +35,11 @@ import com.google.common.io.BaseEncoding import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.parquet.ParquetColumns._ private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" @@ -50,15 +54,25 @@ private[sql] object ParquetFilters { if (filters.length > 0) FilterCompat.get(filters.reduce(FilterApi.and)) else null } - def createFilter(expression: Expression): Option[CatalystFilter] ={ + def createFilter(expression: Expression): Option[CatalystFilter] = { def createEqualityFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case BooleanType => - ComparisonFilter.createBooleanFilter( + ComparisonFilter.createBooleanEqualityFilter( name, - literal.value.asInstanceOf[Boolean], + literal.value.asInstanceOf[Boolean], + predicate) + case ByteType => + new ComparisonFilter( + name, + FilterApi.eq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.eq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), predicate) case IntegerType => new ComparisonFilter( @@ -81,18 +95,49 @@ private[sql] object ParquetFilters { FilterApi.eq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) case StringType => - ComparisonFilter.createStringFilter( + ComparisonFilter.createStringEqualityFilter( name, literal.value.asInstanceOf[String], predicate) + case BinaryType => + ComparisonFilter.createBinaryEqualityFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.eq(dateColumn(name), 
new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.eq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.eq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } def createLessThanFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.lt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.lt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => - new ComparisonFilter( + new ComparisonFilter( name, FilterApi.lt(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) @@ -111,11 +156,47 @@ private[sql] object ParquetFilters { name, FilterApi.lt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringLessThanFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryLessThanFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.lt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.lt(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.lt(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } def createLessThanOrEqualFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.ltEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.ltEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => new ComparisonFilter( name, @@ -136,12 +217,48 @@ private[sql] object ParquetFilters { name, FilterApi.ltEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringLessThanOrEqualFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryLessThanOrEqualFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.ltEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.ltEq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.ltEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } // TODO: combine these two types somehow? 
def createGreaterThanFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.gt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.gt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => new ComparisonFilter( name, @@ -162,11 +279,47 @@ private[sql] object ParquetFilters { name, FilterApi.gt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringGreaterThanFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryGreaterThanFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.gt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.gt(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.gt(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } def createGreaterThanOrEqualFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.gtEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.gtEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => new ComparisonFilter( name, @@ -187,6 +340,32 @@ private[sql] object ParquetFilters { name, FilterApi.gtEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringGreaterThanOrEqualFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryGreaterThanOrEqualFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.gtEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.gtEq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.gtEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } /** @@ -221,9 +400,9 @@ private[sql] object ParquetFilters { case _ => None } } - case p @ EqualTo(left: Literal, right: NamedExpression) => + case p @ EqualTo(left: Literal, right: NamedExpression) if left.dataType != NullType => Some(createEqualityFilter(right.name, left, p)) - case p @ EqualTo(left: NamedExpression, right: Literal) => + case p @ EqualTo(left: NamedExpression, right: Literal) if right.dataType != NullType => Some(createEqualityFilter(left.name, right, p)) case p @ LessThan(left: Literal, right: NamedExpression) => Some(createLessThanFilter(right.name, left, p)) @@ -363,7 +542,7 @@ private[parquet] case class AndFilter( } private[parquet] object ComparisonFilter { - def createBooleanFilter( + def createBooleanEqualityFilter( columnName: String, value: Boolean, predicate: CatalystPredicate): 
CatalystFilter = @@ -372,7 +551,7 @@ private[parquet] object ComparisonFilter { FilterApi.eq(booleanColumn(columnName), value.asInstanceOf[java.lang.Boolean]), predicate) - def createStringFilter( + def createStringEqualityFilter( columnName: String, value: String, predicate: CatalystPredicate): CatalystFilter = @@ -380,4 +559,138 @@ private[parquet] object ComparisonFilter { columnName, FilterApi.eq(binaryColumn(columnName), Binary.fromString(value)), predicate) + + def createStringLessThanFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.lt(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createStringLessThanOrEqualFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.ltEq(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createStringGreaterThanFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gt(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createStringGreaterThanOrEqualFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gtEq(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createBinaryEqualityFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.eq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryLessThanFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.lt(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryLessThanOrEqualFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.ltEq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryGreaterThanFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gt(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryGreaterThanOrEqualFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gtEq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) +} + +private[spark] object ParquetColumns { + + def byteColumn(columnPath: String): ByteColumn = { + new ByteColumn(ColumnPath.fromDotString(columnPath)) + } + + final class ByteColumn(columnPath: ColumnPath) + extends Column[java.lang.Byte](columnPath, classOf[java.lang.Byte]) with SupportsLtGt + + def shortColumn(columnPath: String): ShortColumn = { + new ShortColumn(ColumnPath.fromDotString(columnPath)) + } + + final class ShortColumn(columnPath: ColumnPath) + extends Column[java.lang.Short](columnPath, classOf[java.lang.Short]) with SupportsLtGt + + + def dateColumn(columnPath: String): DateColumn = { + new DateColumn(ColumnPath.fromDotString(columnPath)) + } + + final class DateColumn(columnPath: ColumnPath) + extends Column[WrappedDate](columnPath, classOf[WrappedDate]) with 
SupportsLtGt + + def timestampColumn(columnPath: String): TimestampColumn = { + new TimestampColumn(ColumnPath.fromDotString(columnPath)) + } + + final class TimestampColumn(columnPath: ColumnPath) + extends Column[WrappedTimestamp](columnPath, classOf[WrappedTimestamp]) with SupportsLtGt + + def decimalColumn(columnPath: String): DecimalColumn = { + new DecimalColumn(ColumnPath.fromDotString(columnPath)) + } + + final class DecimalColumn(columnPath: ColumnPath) + extends Column[Decimal](columnPath, classOf[Decimal]) with SupportsLtGt + + final class WrappedDate(val date: Date) extends Comparable[WrappedDate] { + + override def compareTo(other: WrappedDate): Int = { + date.compareTo(other.date) + } + } + + final class WrappedTimestamp(val timestamp: Timestamp) extends Comparable[WrappedTimestamp] { + + override def compareTo(other: WrappedTimestamp): Int = { + timestamp.compareTo(other.timestamp) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 08d9da27f1b1..3cccafe92d4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -619,6 +619,46 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA fail(s"optional Int value in result row $i should be ${6*i}") } } + + val query12 = sql("SELECT * FROM testfiltersource WHERE mystring >= \"50\"") + assert( + query12.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result12 = query12.collect() + assert(result12.size === 54) + assert(result12(0).getString(2) == "6") + assert(result12(4).getString(2) == "50") + assert(result12(53).getString(2) == "99") + + val query13 = sql("SELECT * FROM testfiltersource WHERE mystring > \"50\"") + assert( + query13.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result13 = query13.collect() + assert(result13.size === 53) + assert(result13(0).getString(2) == "6") + assert(result13(4).getString(2) == "51") + assert(result13(52).getString(2) == "99") + + val query14 = sql("SELECT * FROM testfiltersource WHERE mystring <= \"50\"") + assert( + query14.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result14 = query14.collect() + assert(result14.size === 148) + assert(result14(0).getString(2) == "0") + assert(result14(46).getString(2) == "50") + assert(result14(147).getString(2) == "200") + + val query15 = sql("SELECT * FROM testfiltersource WHERE mystring < \"50\"") + assert( + query15.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result15 = query15.collect() + assert(result15.size === 147) + assert(result15(0).getString(2) == "0") + assert(result15(46).getString(2) == "100") + assert(result15(146).getString(2) == "200") } test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { From ff1a0825637690b3fce780d4dcaad68dce382fb9 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 7 Nov 2014 12:15:53 -0800 Subject: [PATCH 052/652] [SPARK-4272] [SQL] Add more unwrapper functions for primitive type in TableReader Currently, the data "unwrap" only support couple of primitive types, 
not all, it will not cause exception, but may get some performance in table scanning for the type like binary, date, timestamp, decimal etc. Author: Cheng Hao Closes #3136 from chenghao-intel/table_reader and squashes the following commits: fffb729 [Cheng Hao] fix bug for retrieving the timestamp object e9c97a4 [Cheng Hao] Add more unwrapper functions for primitive type in TableReader (cherry picked from commit 60ab80f501b8384ddf48a9ac0ba0c2b9eb548b28) Signed-off-by: Michael Armbrust --- .../apache/spark/sql/hive/HiveInspectors.scala | 4 ---- .../org/apache/spark/sql/hive/TableReader.scala | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 58815daa8227..bdc7e1dac192 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -115,10 +115,6 @@ private[hive] trait HiveInspectors { } - /** - * Wraps with Hive types based on object inspector. - * TODO: Consolidate all hive OI/data interface code. - */ /** * Wraps with Hive types based on object inspector. * TODO: Consolidate all hive OI/data interface code. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index e49f0957d188..f60bc3788e3e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -290,6 +290,21 @@ private[hive] object HadoopTableReader extends HiveInspectors { (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) case oi: DoubleObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi: HiveVarcharObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.setString(ordinal, oi.getPrimitiveJavaObject(value).getValue) + case oi: HiveDecimalObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) + case oi: TimestampObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, oi.getPrimitiveJavaObject(value).clone()) + case oi: DateObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, oi.getPrimitiveJavaObject(value)) + case oi: BinaryObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, oi.getPrimitiveJavaObject(value)) case oi => (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) } From 684d1f0ecd77d639557b4ca3c26ced950c9ab9fc Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 7 Nov 2014 12:30:47 -0800 Subject: [PATCH 053/652] [SPARK-4270][SQL] Fix Cast from DateType to DecimalType. `Cast` from `DateType` to `DecimalType` throws `NullPointerException`. Author: Takuya UESHIN Closes #3134 from ueshin/issues/SPARK-4270 and squashes the following commits: 7394e4b [Takuya UESHIN] Fix Cast from DateType to DecimalType. 
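For illustration only, not part of the patch: a self-contained Scala sketch of the failure mode. The old date branch handed a null value to a precision-changing helper that dereferences its argument, hence the NullPointerException; returning null outright, as the fix does, matches Hive. The `Dec` class below is a simplified stand-in, not Spark's `Decimal`.

```
object DateToDecimalSketch {
  // Simplified stand-in for a decimal value; NOT Spark's Decimal class.
  final class Dec(val v: BigDecimal) {
    def fitsPrecision(precision: Int): Boolean = v.precision <= precision
  }

  // Old shape of the date branch: the (null) value is forwarded into a helper
  // that calls a method on it, so the cast blew up with a NullPointerException.
  def oldDateBranch(d: Dec): Dec = if (d.fitsPrecision(10)) d else null

  // New shape: a date can never be cast to a decimal, so just return null (Hive behaviour).
  def newDateBranch(d: Dec): Dec = null

  def main(args: Array[String]): Unit = {
    val before = try { oldDateBranch(null); "no error" } catch { case _: NullPointerException => "NPE" }
    println(s"old branch: $before, new branch: ${newDateBranch(null)}") // old branch: NPE, new branch: null
  }
}
```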
(cherry picked from commit a6405c5ddcda112f8efd7d50d8e5f44f78a0fa41) Signed-off-by: Michael Armbrust --- .../scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 2 +- .../sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 22009666196a..55319e7a7910 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -281,7 +281,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target)) case DateType => - buildCast[Date](_, d => changePrecision(null, target)) // date can't cast to decimal in Hive + buildCast[Date](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 6bfa0dbd65ba..918996f11da2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -412,6 +412,8 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(d, LongType), null) checkEvaluation(Cast(d, FloatType), null) checkEvaluation(Cast(d, DoubleType), null) + checkEvaluation(Cast(d, DecimalType.Unlimited), null) + checkEvaluation(Cast(d, DecimalType(10, 2)), null) checkEvaluation(Cast(d, StringType), "1970-01-01") checkEvaluation(Cast(Cast(d, TimestampType), StringType), "1970-01-01 00:00:00") } From c96da3676c32579d0f97347d35d95353b1d2ef07 Mon Sep 17 00:00:00 2001 From: Matthew Taylor Date: Fri, 7 Nov 2014 12:53:08 -0800 Subject: [PATCH 054/652] [SPARK-4203][SQL] Partition directories in random order when inserting into hive table When doing an insert into hive table with partitions the folders written to the file system are in a random order instead of the order defined in table creation. Seems that the loadPartition method in Hive.java has a Map parameter but expects to be called with a map that has a defined ordering such as LinkedHashMap. 
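For illustration only, not part of the patch: a self-contained Scala sketch of the ordering point, with hypothetical column names. `java.util.LinkedHashMap` iterates in insertion order, so building the spec in partition-column order yields the directory nesting the table defines, whereas a plain `HashMap` gives no such guarantee.

```
import java.util

object PartitionOrderSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical partition columns in table-definition order, plus the user's spec.
    val partCols = Seq("p1", "p2", "p3")
    val partitionSpec = Map("p3" -> "c", "p1" -> "a", "p2" -> "b")

    // A HashMap iterates in an arbitrary order...
    val unordered = new util.HashMap[String, String]()
    partitionSpec.foreach { case (k, v) => unordered.put(k, v) }

    // ...a LinkedHashMap iterates in insertion order, so inserting keys in
    // partition-column order reproduces the directory nesting p1/p2/p3.
    val ordered = new util.LinkedHashMap[String, String]()
    partCols.foreach(c => ordered.put(c, partitionSpec.getOrElse(c, "")))

    println(unordered) // key order not guaranteed
    println(ordered)   // {p1=a, p2=b, p3=c}
  }
}
```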
Working on a test but having IntelliJ problems Author: Matthew Taylor Closes #3076 from tbfenet/partition_dir_order_problem and squashes the following commits: f1b9a52 [Matthew Taylor] Comment format fix bca709f [Matthew Taylor] review changes 0e50f6b [Matthew Taylor] test fix 99f1a31 [Matthew Taylor] partition ordering fix 369e618 [Matthew Taylor] partition ordering fix (cherry picked from commit ac70c972a51952f801fd02dd5962c0a0c1aba8f8) Signed-off-by: Michael Armbrust --- .../hive/execution/InsertIntoHiveTable.scala | 13 +++++-- .../sql/hive/InsertIntoHiveTableSuite.scala | 34 +++++++++++++++++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 74b4e7aaa47a..81390f626726 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import java.util + import scala.collection.JavaConversions._ import org.apache.hadoop.hive.common.`type`.HiveVarchar @@ -203,6 +205,13 @@ case class InsertIntoHiveTable( // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. val holdDDLTime = false if (partition.nonEmpty) { + + // loadPartition call orders directories created on the iteration order of the this map + val orderedPartitionSpec = new util.LinkedHashMap[String,String]() + table.hiveQlTable.getPartCols().foreach{ + entry=> + orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) + } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) db.validatePartitionNameCharacters(partVals) // inheritTableSpecs is set to true.
It should be set to false for a IMPORT query @@ -214,7 +223,7 @@ case class InsertIntoHiveTable( db.loadDynamicPartitions( outputPath, qualifiedTableName, - partitionSpec, + orderedPartitionSpec, overwrite, numDynamicPartitions, holdDDLTime, @@ -224,7 +233,7 @@ case class InsertIntoHiveTable( db.loadPartition( outputPath, qualifiedTableName, - partitionSpec, + orderedPartitionSpec, overwrite, holdDDLTime, inheritTableSpecs, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 18dc937dd2b2..5dbfb923139f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql._ +import java.io.File + +import com.google.common.io.Files +import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.hive.test.TestHive /* Implicits */ @@ -91,4 +93,32 @@ class InsertIntoHiveTableSuite extends QueryTest { sql("DROP TABLE hiveTableWithMapValue") } + + test("SPARK-4203:random partition directory order") { + createTable[TestData]("tmp_table") + val tmpDir = Files.createTempDir() + sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='2') SELECT 'blarr' FROM tmp_table") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='3') SELECT 'blarr' FROM tmp_table") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='4') SELECT 'blarr' FROM tmp_table") + def listFolders(path: File, acc: List[String]): List[List[String]] = { + val dir = path.listFiles() + val folders = dir.filter(_.isDirectory).toList + if (folders.isEmpty) { + List(acc.reverse) + } else { + folders.flatMap(x => listFolders(x, x.getName :: acc)) + } + } + val expected = List( + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=2"::Nil, + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=3"::Nil , + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil + ) + assert(listFolders(tmpDir,List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + sql("DROP TABLE table_with_partition") + sql("DROP TABLE tmp_table") + } } From 47bd8f3020149a009f605e8390c2c28f3f835191 Mon Sep 17 00:00:00 2001 From: wangfei Date: Fri, 7 Nov 2014 12:55:11 -0800 Subject: [PATCH 055/652] [SPARK-4292][SQL] Result set iterator bug in JDBC/ODBC select * from src, get the wrong result set as follows: ``` ... | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | ... 
``` Author: wangfei Closes #3149 from scwf/SPARK-4292 and squashes the following commits: 1574a43 [wangfei] using result.collect 8b2d845 [wangfei] adding test f64eddf [wangfei] result set iter bug (cherry picked from commit d6e55524437026c0c76addeba8f99249a8316716) Signed-off-by: Michael Armbrust --- .../thriftserver/HiveThriftServer2Suite.scala | 23 +++++++++++++++++++ .../spark/sql/hive/thriftserver/Shim12.scala | 5 ++-- .../spark/sql/hive/thriftserver/Shim13.scala | 5 ++-- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 65d910a0c3ff..bba29b2bdca4 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -267,4 +267,27 @@ class HiveThriftServer2Suite extends FunSuite with Logging { assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") } } + + test("SPARK-4292 regression: result set iterator issue") { + withJdbcStatement() { statement => + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + + val queries = Seq( + "DROP TABLE IF EXISTS test_4292", + "CREATE TABLE test_4292(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_4292") + + queries.foreach(statement.execute) + + val resultSet = statement.executeQuery("SELECT key FROM test_4292") + + Seq(238, 86, 311, 27, 165).foreach { key => + resultSet.next() + assert(resultSet.getInt(1) == key) + } + + statement.executeQuery("DROP TABLE IF EXISTS test_4292") + } + } } diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index 8077d0ec46fd..e3ba9914c6cc 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -202,13 +202,12 @@ private[hive] class SparkExecuteStatementOperation( hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) } iter = { - val resultRdd = result.queryExecution.toRdd val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - resultRdd.toLocalIterator + result.toLocalIterator } else { - resultRdd.collect().iterator + result.collect().iterator } } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index 2c1983de1d0d..f2ceba828296 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -87,13 +87,12 @@ private[hive] class SparkExecuteStatementOperation( val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) iter = { - val resultRdd = result.queryExecution.toRdd val 
useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - resultRdd.toLocalIterator + result.toLocalIterator } else { - resultRdd.collect().iterator + result.collect().iterator } } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray From 8cefb63c122e7c7cf4af959f9606f4491148d9f4 Mon Sep 17 00:00:00 2001 From: xiao321 <1042460381@qq.com> Date: Fri, 7 Nov 2014 12:56:49 -0800 Subject: [PATCH 056/652] Update JavaCustomReceiver.java MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Array index out of bounds (数组下标越界) Author: xiao321 <1042460381@qq.com> Closes #3153 from xiao321/patch-1 and squashes the following commits: 0ed17b5 [xiao321] Update JavaCustomReceiver.java (cherry picked from commit 7c9ec529a3483fab48f728481dd1d3663369e50a) Signed-off-by: Tathagata Das --- .../org/apache/spark/examples/streaming/JavaCustomReceiver.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 981bc4f0613a..99df259b4e8e 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -70,7 +70,7 @@ public static void main(String[] args) { // Create a input stream with the custom receiver on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') JavaReceiverInputDStream lines = ssc.receiverStream( - new JavaCustomReceiver(args[1], Integer.parseInt(args[2]))); + new JavaCustomReceiver(args[0], Integer.parseInt(args[1]))); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { From 3b07c483aa98965ac9dc8fdcc40e593e4edb97fd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 7 Nov 2014 20:53:03 -0800 Subject: [PATCH 057/652] [SPARK-4304] [PySpark] Fix sort on empty RDD This PR fixes sortBy()/sortByKey() on an empty RDD.
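For reference, not part of the patch: the Scala API already behaves this way, and the PySpark fix brings it in line with the sketch below (assumes a live SparkContext `sc`, for example in the spark-shell on this branch).

```
// Sketch only: assumes an existing SparkContext `sc`.
import org.apache.spark.SparkContext._ // pair-RDD functions such as sortByKey

val empty = sc.parallelize(Seq.empty[(Int, String)])
assert(empty.sortByKey().collect().isEmpty) // no failure while sampling range bounds
assert(empty.sortBy(_._1).collect().isEmpty)
```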
This should be back ported into 1.1/1.2 Author: Davies Liu Closes #3162 from davies/fix_sort and squashes the following commits: 84f64b7 [Davies Liu] add tests 52995b5 [Davies Liu] fix sortByKey() on empty RDD (cherry picked from commit 7779109796c90d789464ab0be35917f963bbe867) Signed-off-by: Josh Rosen --- python/pyspark/rdd.py | 2 ++ python/pyspark/tests.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 879655dc53f4..08d047402625 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -521,6 +521,8 @@ def sortPartition(iterator): # the key-space into bins such that the bins have roughly the same # number of (key, value) pairs falling into them rddSize = self.count() + if not rddSize: + return self # empty RDD maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 9f625c5c6ca4..491e445a216b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -649,6 +649,9 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_sort_on_empty_rdd(self): + self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) + def test_sample(self): rdd = self.sc.parallelize(range(0, 100), 4) wo = rdd.sample(False, 0.1, 2).collect() From 427d7911f527e00e75dec0498b4bbdbe164db7ca Mon Sep 17 00:00:00 2001 From: Michelangelo D'Agostino Date: Fri, 7 Nov 2014 22:53:01 -0800 Subject: [PATCH 058/652] [MLLIB] [PYTHON] SPARK-4221: Expose nonnegative ALS in the python API SPARK-1553 added alternating nonnegative least squares to MLLib, however it's not possible to access it via the python API. This pull request resolves that. Author: Michelangelo D'Agostino Closes #3095 from mdagost/python_nmf and squashes the following commits: a6743ad [Michelangelo D'Agostino] Use setters instead of static methods in PythonMLLibAPI. Remove the new static methods I added. Set seed in tests. Change ratings to ratingsRDD in both train and trainImplicit for consistency. 7cffd39 [Michelangelo D'Agostino] Swapped nonnegative and seed in a few more places. 3fdc851 [Michelangelo D'Agostino] Moved seed to the end of the python parameter list. bdcc154 [Michelangelo D'Agostino] Change seed type to java.lang.Long so that it can handle null. cedf043 [Michelangelo D'Agostino] Added in ability to set the seed from python and made that play nice with the nonnegative changes. Also made the python ALS tests more exact. a72fdc9 [Michelangelo D'Agostino] Expose nonnegative ALS in the python API. (cherry picked from commit 7e9d975676d56ace0e84c2200137e4cd4eba074a) Signed-off-by: Xiangrui Meng --- .../mllib/api/python/PythonMLLibAPI.scala | 39 +++++++++++++++--- python/pyspark/mllib/recommendation.py | 40 ++++++++++++------- 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index d832ae34b55e..70d7138e3060 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -275,12 +275,25 @@ class PythonMLLibAPI extends Serializable { * the Py4J documentation. 
*/ def trainALSModel( - ratings: JavaRDD[Rating], + ratingsJRDD: JavaRDD[Rating], rank: Int, iterations: Int, lambda: Double, - blocks: Int): MatrixFactorizationModel = { - new MatrixFactorizationModelWrapper(ALS.train(ratings.rdd, rank, iterations, lambda, blocks)) + blocks: Int, + nonnegative: Boolean, + seed: java.lang.Long): MatrixFactorizationModel = { + + val als = new ALS() + .setRank(rank) + .setIterations(iterations) + .setLambda(lambda) + .setBlocks(blocks) + .setNonnegative(nonnegative) + + if (seed != null) als.setSeed(seed) + + val model = als.run(ratingsJRDD.rdd) + new MatrixFactorizationModelWrapper(model) } /** @@ -295,9 +308,23 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int, - alpha: Double): MatrixFactorizationModel = { - new MatrixFactorizationModelWrapper( - ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)) + alpha: Double, + nonnegative: Boolean, + seed: java.lang.Long): MatrixFactorizationModel = { + + val als = new ALS() + .setImplicitPrefs(true) + .setRank(rank) + .setIterations(iterations) + .setLambda(lambda) + .setBlocks(blocks) + .setAlpha(alpha) + .setNonnegative(nonnegative) + + if (seed != null) als.setSeed(seed) + + val model = als.run(ratingsJRDD.rdd) + new MatrixFactorizationModelWrapper(model) } /** diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index e8b998414d31..e26b152e0cdf 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -44,31 +44,39 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> r2 = (1, 2, 2.0) >>> r3 = (2, 1, 2.0) >>> ratings = sc.parallelize([r1, r2, r3]) - >>> model = ALS.trainImplicit(ratings, 1) - >>> model.predict(2,2) is not None - True + >>> model = ALS.trainImplicit(ratings, 1, seed=10) + >>> model.predict(2,2) + 0.4473... >>> testset = sc.parallelize([(1, 2), (1, 1)]) - >>> model = ALS.train(ratings, 1) - >>> model.predictAll(testset).count() == 2 - True + >>> model = ALS.train(ratings, 1, seed=10) + >>> model.predictAll(testset).collect() + [Rating(1, 1, 1), Rating(1, 2, 1)] - >>> model = ALS.train(ratings, 4) - >>> model.userFeatures().count() == 2 - True + >>> model = ALS.train(ratings, 4, seed=10) + >>> model.userFeatures().collect() + [(2, array('d', [...])), (1, array('d', [...]))] >>> first_user = model.userFeatures().take(1)[0] >>> latents = first_user[1] >>> len(latents) == 4 True - >>> model.productFeatures().count() == 2 - True + >>> model.productFeatures().collect() + [(2, array('d', [...])), (1, array('d', [...]))] >>> first_product = model.productFeatures().take(1)[0] >>> latents = first_product[1] >>> len(latents) == 4 True + + >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) + >>> model.predict(2,2) + 3.735... + + >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) + >>> model.predict(2,2) + 0.4473... 
""" def predict(self, user, product): return self._java_model.predict(user, product) @@ -101,15 +109,17 @@ def _prepare(cls, ratings): return _to_java_object_rdd(ratings, True) @classmethod - def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): + def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, + seed=None): model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, - lambda_, blocks) + lambda_, blocks, nonnegative, seed) return MatrixFactorizationModel(model) @classmethod - def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01): + def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, + nonnegative=False, seed=None): model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, - iterations, lambda_, blocks, alpha) + iterations, lambda_, blocks, alpha, nonnegative, seed) return MatrixFactorizationModel(model) From fc51de3395f25983052ae9d3c5c17891f6e6b8a7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 7 Nov 2014 23:16:13 -0800 Subject: [PATCH 059/652] [SPARK-4291][Build] Rename network module projects The names of the recently introduced network modules are inconsistent with those of the other modules in the project. We should just drop the "Code" suffix since it doesn't sacrifice any meaning, especially before they get into an official release. ``` [INFO] Reactor Build Order: [INFO] [INFO] Spark Project Parent POM [INFO] Spark Project Common Network Code [INFO] Spark Project Shuffle Streaming Service Code [INFO] Spark Project Core [INFO] Spark Project Bagel [INFO] Spark Project GraphX [INFO] Spark Project Streaming [INFO] Spark Project Catalyst [INFO] Spark Project SQL [INFO] Spark Project ML Library [INFO] Spark Project Tools [INFO] Spark Project Hive [INFO] Spark Project REPL [INFO] Spark Project YARN Parent POM [INFO] Spark Project YARN Stable API [INFO] Spark Project Assembly [INFO] Spark Project External Twitter [INFO] Spark Project External Kafka [INFO] Spark Project External Flume Sink [INFO] Spark Project External Flume [INFO] Spark Project External ZeroMQ [INFO] Spark Project External MQTT [INFO] Spark Project Examples [INFO] Spark Project Yarn Shuffle Service Code ``` Author: Andrew Or Closes #3148 from andrewor14/build-drop-code and squashes the following commits: eac839b [Andrew Or] Network -> Networking d01ad47 [Andrew Or] Rename network module project names (cherry picked from commit 7afc8564f33eb2868f458f85046f59a51b516ed6) Signed-off-by: Patrick Wendell --- network/common/pom.xml | 2 +- network/shuffle/pom.xml | 2 +- network/yarn/pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/network/common/pom.xml b/network/common/pom.xml index 6144548a8f99..8b24ebf1ba1f 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -29,7 +29,7 @@ org.apache.spark spark-network-common_2.10 jar - Spark Project Common Network Code + Spark Project Networking http://spark.apache.org/ network-common diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index fe5681d46349..27c8467687f1 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -29,7 +29,7 @@ org.apache.spark spark-network-shuffle_2.10 jar - Spark Project Shuffle Streaming Service Code + Spark Project Shuffle Streaming Service http://spark.apache.org/ network-shuffle diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index e60d8c1f7876..6e6f6f3e7929 100644 --- a/network/yarn/pom.xml +++ 
b/network/yarn/pom.xml @@ -29,7 +29,7 @@ org.apache.spark spark-network-yarn_2.10 jar - Spark Project Yarn Shuffle Service Code + Spark Project YARN Shuffle Service http://spark.apache.org/ network-yarn From 05bffcc023989fb09281e59cbc094f6990527c51 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Sat, 8 Nov 2014 13:03:51 -0800 Subject: [PATCH 060/652] [Minor] [Core] Don't NPE on closeQuietly(null) Author: Aaron Davidson Closes #3166 from aarondav/closeQuietlyer and squashes the following commits: 78096b5 [Aaron Davidson] Don't NPE on closeQuietly(null) (cherry picked from commit 4af5c7e24455246c61c1f3c22225507e720d721d) Signed-off-by: Reynold Xin --- .../main/java/org/apache/spark/network/util/JavaUtils.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 009dbcf01323..bf8a1fc42fc6 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -44,7 +44,9 @@ public class JavaUtils { /** Closes the given object, ignoring IOExceptions. */ public static void closeQuietly(Closeable closeable) { try { - closeable.close(); + if (closeable != null) { + closeable.close(); + } } catch (IOException e) { logger.error("IOException should not have been thrown.", e); } From 21b9ac062f9b9c4db7596195f8b3731596a16c9f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Nov 2014 18:10:23 -0800 Subject: [PATCH 061/652] [SPARK-4301] StreamingContext should not allow start() to be called after calling stop() In Spark 1.0.0+, calling `stop()` on a StreamingContext that has not been started is a no-op which has no side-effects. This allows users to call `stop()` on a fresh StreamingContext followed by `start()`. I believe that this almost always indicates an error and is not behavior that we should support. Since we don't allow `start() stop() start()` then I don't think it makes sense to allow `stop() start()`. The current behavior can lead to resource leaks when StreamingContext constructs its own SparkContext: if I call `stop(stopSparkContext=True)`, then I expect StreamingContext's underlying SparkContext to be stopped irrespective of whether the StreamingContext has been started. This is useful when writing unit test fixtures. Prior discussions: - https://github.com/apache/spark/pull/3053#discussion-diff-19710333R490 - https://github.com/apache/spark/pull/3121#issuecomment-61927353 Author: Josh Rosen Closes #3160 from JoshRosen/SPARK-4301 and squashes the following commits: dbcc929 [Josh Rosen] Address more review comments bdbe5da [Josh Rosen] Stop SparkContext after stopping scheduler, not before. 03e9c40 [Josh Rosen] Always stop SparkContext, even if stop(false) has already been called. 832a7f4 [Josh Rosen] Address review comment 5142517 [Josh Rosen] Add tests; improve Scaladoc. 813e471 [Josh Rosen] Revert workaround added in https://github.com/apache/spark/pull/3053/files#diff-e144dbee130ed84f9465853ddce65f8eR49 5558e70 [Josh Rosen] StreamingContext.stop() should stop SparkContext even if StreamingContext has not been started yet. 
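For illustration, not part of the patch: a sketch of the test-fixture pattern this change supports, assuming local mode; the app name and batch interval are arbitrary.

```
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// A StreamingContext that owns its SparkContext can now be torn down cleanly even
// if it was never started: stop(stopSparkContext = true) stops the underlying
// SparkContext regardless of the streaming state.
val conf = new SparkConf().setMaster("local[2]").setAppName("stop-before-start")
val ssc = new StreamingContext(conf, Seconds(1))
// ... a test fixture sets things up but never calls ssc.start() ...
ssc.stop(stopSparkContext = true)
// Calling ssc.start() after stop() now fails fast with a SparkException.
```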
(cherry picked from commit 7b41b17f3296eea3282efbdceb6b28baf128287d) Signed-off-by: Tathagata Das --- .../spark/streaming/StreamingContext.scala | 38 ++++++++++--------- .../streaming/StreamingContextSuite.scala | 25 +++++++++--- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 23d6d1c5e50f..54b219711efb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -436,10 +436,10 @@ class StreamingContext private[streaming] ( /** * Start the execution of the streams. + * + * @throws SparkException if the context has already been started or stopped. */ def start(): Unit = synchronized { - // Throw exception if the context has already been started once - // or if a stopped context is being started again if (state == Started) { throw new SparkException("StreamingContext has already been started") } @@ -472,8 +472,10 @@ class StreamingContext private[streaming] ( /** * Stop the execution of the streams immediately (does not wait for all received data * to be processed). - * @param stopSparkContext Stop the associated SparkContext or not * + * @param stopSparkContext if true, stops the associated SparkContext. The underlying SparkContext + * will be stopped regardless of whether this StreamingContext has been + * started. */ def stop(stopSparkContext: Boolean = true): Unit = synchronized { stop(stopSparkContext, false) @@ -482,25 +484,27 @@ class StreamingContext private[streaming] ( /** * Stop the execution of the streams, with option of ensuring all received data * has been processed. - * @param stopSparkContext Stop the associated SparkContext or not - * @param stopGracefully Stop gracefully by waiting for the processing of all + * + * @param stopSparkContext if true, stops the associated SparkContext. The underlying SparkContext + * will be stopped regardless of whether this StreamingContext has been + * started. + * @param stopGracefully if true, stops gracefully by waiting for the processing of all * received data to be completed */ def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized { - // Warn (but not fail) if context is stopped twice, - // or context is stopped before starting - if (state == Initialized) { - logWarning("StreamingContext has not been started yet") - return + state match { + case Initialized => logWarning("StreamingContext has not been started yet") + case Stopped => logWarning("StreamingContext has already been stopped") + case Started => + scheduler.stop(stopGracefully) + logInfo("StreamingContext stopped successfully") + waiter.notifyStop() } - if (state == Stopped) { - logWarning("StreamingContext has already been stopped") - return - } // no need to throw an exception as its okay to stop twice - scheduler.stop(stopGracefully) - logInfo("StreamingContext stopped successfully") - waiter.notifyStop() + // Even if the streaming context has not been started, we still need to stop the SparkContext. + // Even if we have already stopped, we still need to attempt to stop the SparkContext because + // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). 
if (stopSparkContext) sc.stop() + // The state should always be Stopped after calling `stop()`, even if we haven't started yet: state = Stopped } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index f47772947d67..4b49c4d25164 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -46,10 +46,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w after { if (ssc != null) { ssc.stop() - if (ssc.sc != null) { - // Calling ssc.stop() does not always stop the associated SparkContext. - ssc.sc.stop() - } ssc = null } if (sc != null) { @@ -137,11 +133,16 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc.stop() } - test("stop before start and start after stop") { + test("stop before start") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() ssc.stop() // stop before start should not throw exception - ssc.start() + } + + test("start after stop") { + // Regression test for SPARK-4301 + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register() ssc.stop() intercept[SparkException] { ssc.start() // start after stop should throw exception @@ -161,6 +162,18 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc.stop() } + test("stop(stopSparkContext=true) after stop(stopSparkContext=false)") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register() + ssc.stop(stopSparkContext = false) + assert(ssc.sc.makeRDD(1 to 100).collect().size === 100) + ssc.stop(stopSparkContext = true) + // Check that the SparkContext is actually stopped: + intercept[Exception] { + ssc.sc.makeRDD(1 to 100).collect() + } + } + test("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) conf.set("spark.cleaner.ttl", "3600") From 6824af0c3a29aa2d11606495c4a95915233ba96e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 9 Nov 2014 17:40:48 -0800 Subject: [PATCH 062/652] SPARK-971 [DOCS] Link to Confluence wiki from project website / documentation This is a trivial change to add links to the wiki from `README.md` and the main docs page. It is already linked to from spark.apache.org. Author: Sean Owen Closes #3169 from srowen/SPARK-971 and squashes the following commits: dcb84d0 [Sean Owen] Add link to wiki from README, docs home page (cherry picked from commit 8c99a47a4f0369ff3c1ecaeb860fa61ee789e987) Signed-off-by: Patrick Wendell --- README.md | 3 ++- docs/index.md | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9916ac7b1ae8..8d57d50da96c 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,8 @@ and Spark Streaming for stream processing. ## Online Documentation You can find the latest Spark documentation, including a programming -guide, on the [project web page](http://spark.apache.org/documentation.html). +guide, on the [project web page](http://spark.apache.org/documentation.html) +and [project wiki](https://cwiki.apache.org/confluence/display/SPARK). This README file only contains basic setup instructions. 
## Building Spark diff --git a/docs/index.md b/docs/index.md index edd622ec90f6..171d6ddad62f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -112,6 +112,7 @@ options for deployment: **External Resources:** * [Spark Homepage](http://spark.apache.org) +* [Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK) * [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/), From a9debe8fe19fc980d860a41d77f53ac21fb49d0c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 9 Nov 2014 17:42:08 -0800 Subject: [PATCH 063/652] SPARK-1344 [DOCS] Scala API docs for top methods Use "k" in javadoc of top and takeOrdered to avoid confusion with type K in pair RDDs. I think this resolves the discussion in SPARK-1344. Author: Sean Owen Closes #3168 from srowen/SPARK-1344 and squashes the following commits: 6963fcc [Sean Owen] Use "k" in javadoc of top and takeOrdered to avoid confusion with type K in pair RDDs (cherry picked from commit d1362659ef5d62db2c9ff0d2a24639abcef4e118) Signed-off-by: Patrick Wendell --- .../org/apache/spark/api/java/JavaRDDLike.scala | 16 ++++++++-------- .../main/scala/org/apache/spark/rdd/RDD.scala | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index efb8978f7ce1..5a8e5bb1f721 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -493,9 +493,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the top K elements from this RDD as defined by + * Returns the top k (largest) elements from this RDD as defined by * the specified Comparator[T]. - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @param comp the comparator that defines the order * @return an array of top elements */ @@ -507,9 +507,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the top K elements from this RDD using the + * Returns the top k (largest) elements from this RDD using the * natural ordering for T. - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @return an array of top elements */ def top(num: Int): JList[T] = { @@ -518,9 +518,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the first K elements from this RDD as defined by + * Returns the first k (smallest) elements from this RDD as defined by * the specified Comparator[T] and maintains the order. - * @param num the number of top elements to return + * @param num k, the number of elements to return * @param comp the comparator that defines the order * @return an array of top elements */ @@ -552,9 +552,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the first K elements from this RDD using the + * Returns the first k (smallest) elements from this RDD using the * natural ordering for T while maintain the order. 
- * @param num the number of top elements to return + * @param num k, the number of top elements to return * @return an array of top elements */ def takeOrdered(num: Int): JList[T] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index c169b2d3fe97..716f2dd17733 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1096,7 +1096,7 @@ abstract class RDD[T: ClassTag]( } /** - * Returns the top K (largest) elements from this RDD as defined by the specified + * Returns the top k (largest) elements from this RDD as defined by the specified * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example: * {{{ * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1) @@ -1106,14 +1106,14 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @param ord the implicit ordering for T * @return an array of top elements */ def top(num: Int)(implicit ord: Ordering[T]): Array[T] = takeOrdered(num)(ord.reverse) /** - * Returns the first K (smallest) elements from this RDD as defined by the specified + * Returns the first k (smallest) elements from this RDD as defined by the specified * implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]]. * For example: * {{{ @@ -1124,7 +1124,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * - * @param num the number of top elements to return + * @param num k, the number of elements to return * @param ord the implicit ordering for T * @return an array of top elements */ From 42d19aec13a290984def7287411262c434cb6a69 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 9 Nov 2014 22:11:20 -0800 Subject: [PATCH 064/652] SPARK-1209 [CORE] (Take 2) SparkHadoop{MapRed,MapReduce}Util should not use package org.apache.hadoop andrewor14 Another try at SPARK-1209, to address https://github.com/apache/spark/pull/2814#issuecomment-61197619 I successfully tested with `mvn -Dhadoop.version=1.0.4 -DskipTests clean package; mvn -Dhadoop.version=1.0.4 test` I assume that is what failed Jenkins last time. I also tried `-Dhadoop.version1.2.1` and `-Phadoop-2.4 -Pyarn -Phive` for more coverage. So this is why the class was put in `org.apache.hadoop` to begin with, I assume. One option is to leave this as-is for now and move it only when Hadoop 1.0.x support goes away. This is the other option, which adds a call to force the constructor to be public at run-time. It's probably less surprising than putting Spark code in `org.apache.hadoop`, but, does involve reflection. A `SecurityManager` might forbid this, but it would forbid a lot of stuff Spark does. This would also only affect Hadoop 1.0.x it seems. Author: Sean Owen Closes #3048 from srowen/SPARK-1209 and squashes the following commits: 0d48f4b [Sean Owen] For Hadoop 1.0.x, make certain constructors public, which were public in later versions 466e179 [Sean Owen] Disable MIMA warnings resulting from moving the class -- this was also part of the PairRDDFunctions type hierarchy though? 
eb61820 [Sean Owen] Move SparkHadoopMapRedUtil / SparkHadoopMapReduceUtil from org.apache.hadoop to org.apache.spark (cherry picked from commit f8e5732307dcb1482d9bcf1162a1090ef9a7b913) Signed-off-by: Patrick Wendell --- .../org/apache/spark/SparkHadoopWriter.scala | 1 + .../mapred/SparkHadoopMapRedUtil.scala | 17 +++++++++++++++-- .../mapreduce/SparkHadoopMapReduceUtil.scala | 5 +++-- .../org/apache/spark/rdd/NewHadoopRDD.scala | 1 + .../org/apache/spark/rdd/PairRDDFunctions.scala | 3 ++- project/MimaExcludes.scala | 8 ++++++++ .../sql/parquet/ParquetTableOperations.scala | 1 + .../spark/sql/hive/hiveWriterContainers.scala | 1 + 8 files changed, 32 insertions(+), 5 deletions(-) rename core/src/main/scala/org/apache/{hadoop => spark}/mapred/SparkHadoopMapRedUtil.scala (79%) rename core/src/main/scala/org/apache/{hadoop => spark}/mapreduce/SparkHadoopMapReduceUtil.scala (96%) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 376e69cd997d..40237596570d 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD /** diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala similarity index 79% rename from core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala rename to core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 0c47afae54c8..21b782edd2a9 100644 --- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -15,15 +15,24 @@ * limitations under the License. */ -package org.apache.hadoop.mapred +package org.apache.spark.mapred -private[apache] +import java.lang.reflect.Modifier + +import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext} + +private[spark] trait SparkHadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = { val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext") val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID]) + // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private. + // Make it accessible if it's not in order to access it. 
+ if (!Modifier.isPublic(ctor.getModifiers)) { + ctor.setAccessible(true) + } ctor.newInstance(conf, jobId).asInstanceOf[JobContext] } @@ -31,6 +40,10 @@ trait SparkHadoopMapRedUtil { val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext") val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID]) + // See above + if (!Modifier.isPublic(ctor.getModifiers)) { + ctor.setAccessible(true) + } ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] } diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala similarity index 96% rename from core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala rename to core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 1fca5729c609..3340673f9115 100644 --- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -15,13 +15,14 @@ * limitations under the License. */ -package org.apache.hadoop.mapreduce +package org.apache.spark.mapreduce import java.lang.{Boolean => JBoolean, Integer => JInteger} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} -private[apache] +private[spark] trait SparkHadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = { val klass = firstAvailableClass( diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 351e145f96f9..e55d03d391e0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -35,6 +35,7 @@ import org.apache.spark.Partition import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.Utils import org.apache.spark.deploy.SparkHadoopUtil diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index da89f634abae..462f0d6268a8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -33,13 +33,14 @@ import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, -RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil} +RecordWriter => NewRecordWriter} import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils diff --git 
a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 6a0495f8fd54..a94d09be3bec 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -77,6 +77,14 @@ object MimaExcludes { // SPARK-3822 ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + ) ++ Seq( + // SPARK-1209 + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.rdd.PairRDDFunctions") ) case v if v.startsWith("1.1") => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index d00860a8bb8a..74c43e053b03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -43,6 +43,7 @@ import parquet.hadoop.util.ContextUtil import parquet.io.ParquetDecodingException import parquet.schema.MessageType +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.SQLConf diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index bf2ce9df67c5..cc8bb3e172c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc} From fb36cf9ea8d55dfe3119f6f5d8bd3e98ce68ce21 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sun, 9 Nov 2014 22:29:03 -0800 Subject: [PATCH 065/652] SPARK-3179. Add task OutputMetrics. 
Author: Sandy Ryza This patch had conflicts when merged, resolved by Committer: Kay Ousterhout Closes #2968 from sryza/sandy-spark-3179 and squashes the following commits: dce4784 [Sandy Ryza] More review feedback 8d350d1 [Sandy Ryza] Fix test against Hadoop 2.5+ e7c74d0 [Sandy Ryza] More review feedback 6cff9c4 [Sandy Ryza] Review feedback fb2dde0 [Sandy Ryza] SPARK-3179 (cherry picked from commit 3c2cff4b9464f8d7535564fcd194631a8e5bb0a5) Signed-off-by: Kay Ousterhout --- .../apache/spark/deploy/SparkHadoopUtil.scala | 46 ++++++- .../apache/spark/executor/TaskMetrics.scala | 28 ++++ .../apache/spark/rdd/PairRDDFunctions.scala | 51 ++++++- .../apache/spark/scheduler/JobLogger.scala | 7 +- .../scala/org/apache/spark/ui/ToolTips.scala | 2 + .../apache/spark/ui/exec/ExecutorsTab.scala | 5 + .../apache/spark/ui/jobs/ExecutorTable.scala | 3 + .../spark/ui/jobs/JobProgressListener.scala | 6 + .../org/apache/spark/ui/jobs/StagePage.scala | 29 +++- .../org/apache/spark/ui/jobs/StageTable.scala | 4 + .../org/apache/spark/ui/jobs/UIData.scala | 2 + .../org/apache/spark/util/JsonProtocol.scala | 21 ++- ...te.scala => InputOutputMetricsSuite.scala} | 41 +++++- .../spark/scheduler/SparkListenerSuite.scala | 1 + .../ui/jobs/JobProgressListenerSuite.scala | 7 + .../apache/spark/util/JsonProtocolSuite.scala | 124 ++++++++++++++++-- 16 files changed, 346 insertions(+), 31 deletions(-) rename core/src/test/scala/org/apache/spark/metrics/{InputMetricsSuite.scala => InputOutputMetricsSuite.scala} (67%) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e28eaad8a518..60ee115e393c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy +import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration @@ -133,14 +134,9 @@ class SparkHadoopUtil extends Logging { */ private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) : Option[() => Long] = { - val qualifiedPath = path.getFileSystem(conf).makeQualified(path) - val scheme = qualifiedPath.toUri().getScheme() - val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) try { - val threadStats = stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) - val statisticsDataClass = - Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") - val getBytesReadMethod = statisticsDataClass.getDeclaredMethod("getBytesRead") + val threadStats = getFileSystemThreadStatistics(path, conf) + val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead") val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesRead = f() Some(() => f() - baselineBytesRead) @@ -151,6 +147,42 @@ class SparkHadoopUtil extends Logging { } } } + + /** + * Returns a function that can be called to find Hadoop FileSystem bytes written. If + * getFSBytesWrittenOnThreadCallback is called from thread r at time t, the returned callback will + * return the bytes written on r since t. Reflection is required because thread-level FileSystem + * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). + * Returns None if the required method can't be found. 
+ */ + private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration) + : Option[() => Long] = { + try { + val threadStats = getFileSystemThreadStatistics(path, conf) + val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten") + val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum + val baselineBytesWritten = f() + Some(() => f() - baselineBytesWritten) + } catch { + case e: NoSuchMethodException => { + logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e) + None + } + } + } + + private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = { + val qualifiedPath = path.getFileSystem(conf).makeQualified(path) + val scheme = qualifiedPath.toUri().getScheme() + val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) + } + + private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { + val statisticsDataClass = + Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + statisticsDataClass.getDeclaredMethod(methodName) + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 57bc2b40cec4..51b5328cb4c8 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -82,6 +82,12 @@ class TaskMetrics extends Serializable { */ var inputMetrics: Option[InputMetrics] = None + /** + * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much + * data was written are stored here. + */ + var outputMetrics: Option[OutputMetrics] = None + /** * If this task reads from shuffle output, metrics on getting shuffle data will be collected here. * This includes read metrics aggregated over all the task's shuffle dependencies. @@ -157,6 +163,16 @@ object DataReadMethod extends Enumeration with Serializable { val Memory, Disk, Hadoop, Network = Value } +/** + * :: DeveloperApi :: + * Method by which output data was written. + */ +@DeveloperApi +object DataWriteMethod extends Enumeration with Serializable { + type DataWriteMethod = Value + val Hadoop = Value +} + /** * :: DeveloperApi :: * Metrics about reading input data. @@ -169,6 +185,18 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { var bytesRead: Long = 0L } +/** + * :: DeveloperApi :: + * Metrics about writing output data. + */ +@DeveloperApi +case class OutputMetrics(writeMethod: DataWriteMethod.Value) { + /** + * Total bytes written + */ + var bytesWritten: Long = 0L +} + /** * :: DeveloperApi :: * Metrics pertaining to shuffle data read in a given task. 
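For illustration, not part of the patch: one way to consume the new per-task output metrics through the listener API on this branch; the aggregation below is a sketch.

```
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Sums bytes written by finished tasks using the outputMetrics field added above.
// Register with sc.addSparkListener(new OutputBytesListener()).
class OutputBytesListener extends SparkListener {
  @volatile var totalBytesWritten = 0L

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    // taskMetrics can be null for failed tasks; outputMetrics is only set for Hadoop writes.
    for (metrics <- Option(taskEnd.taskMetrics); out <- metrics.outputMetrics) {
      totalBytesWritten += out.bytesWritten
    }
  }
}
```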
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 462f0d6268a8..8c2c959e73bb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -28,7 +28,7 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} @@ -40,6 +40,7 @@ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer @@ -962,30 +963,40 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { + val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val hadoopContext = newTaskAttemptContext(config, attemptId) val format = outfmt.newInstance format match { - case c: Configurable => c.setConf(wrappedConf.value) + case c: Configurable => c.setConf(config) case _ => () } val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) + + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] try { + var recordsWritten = 0L while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + recordsWritten += 1 } } finally { writer.close(hadoopContext) } committer.commitTask(hadoopContext) + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } 1 } : Int @@ -1006,6 +1017,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf) { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf + val wrappedConf = new SerializableWritable(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass @@ -1033,27 +1045,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.preSetup() val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { + val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. 
val attemptNumber = (context.attemptId % Int.MaxValue).toInt + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() try { + var recordsWritten = 0L while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + recordsWritten += 1 } } finally { writer.close() } writer.commit() + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } } self.context.runJob(self, writeToFile) writer.commitJob() } + private def initHadoopOutputMetrics(context: TaskContext, config: Configuration) + : (OutputMetrics, Option[() => Long]) = { + val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir")) + .map(new Path(_)) + .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config)) + val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) + if (bytesWrittenCallback.isDefined) { + context.taskMetrics.outputMetrics = Some(outputMetrics) + } + (outputMetrics, bytesWrittenCallback) + } + + private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long], + outputMetrics: OutputMetrics, recordsWritten: Long): Unit = { + if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0 + && bytesWrittenCallback.isDefined) { + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } + } + } + /** * Return an RDD with the keys of each tuple. */ @@ -1070,3 +1111,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord) } + +private[spark] object PairRDDFunctions { + val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 4e3d9de54078..3bb54855bae4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -158,6 +158,11 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " INPUT_BYTES=" + metrics.bytesRead case None => "" } + val outputMetrics = taskMetrics.outputMetrics match { + case Some(metrics) => + " OUTPUT_BYTES=" + metrics.bytesWritten + case None => "" + } val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match { case Some(metrics) => " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + @@ -173,7 +178,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime case None => "" } - stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + + stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics + shuffleReadMetrics + writeMetrics) } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 51dc08f668a4..6f446c5a95a0 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -29,6 +29,8 @@ private[spark] object ToolTips { val INPUT = "Bytes read from Hadoop or from Spark storage." + val OUTPUT = "Bytes written to Hadoop." 
+ val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." val SHUFFLE_READ = diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index ba97630f025c..dd1c2b78c409 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -48,6 +48,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val executorToTasksFailed = HashMap[String, Int]() val executorToDuration = HashMap[String, Long]() val executorToInputBytes = HashMap[String, Long]() + val executorToOutputBytes = HashMap[String, Long]() val executorToShuffleRead = HashMap[String, Long]() val executorToShuffleWrite = HashMap[String, Long]() @@ -78,6 +79,10 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp executorToInputBytes(eid) = executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead } + metrics.outputMetrics.foreach { outputMetrics => + executorToOutputBytes(eid) = + executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten + } metrics.shuffleReadMetrics.foreach { shuffleRead => executorToShuffleRead(eid) = executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index f0e43fbf7097..fa0f96bff34f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -45,6 +45,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr Failed Tasks Succeeded Tasks Input + Output Shuffle Read Shuffle Write Shuffle Spill (Memory) @@ -77,6 +78,8 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr {v.succeededTasks} {Utils.bytesToString(v.inputBytes)} + + {Utils.bytesToString(v.outputBytes)} {Utils.bytesToString(v.shuffleRead)} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index e3223403c17f..8bbde51e1801 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -259,6 +259,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.inputBytes += inputBytesDelta execSummary.inputBytes += inputBytesDelta + val outputBytesDelta = + (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L) + - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L)) + stageData.outputBytes += outputBytesDelta + execSummary.outputBytes += outputBytesDelta + val diskSpillDelta = taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) stageData.diskBytesSpilled += diskSpillDelta diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 250bddbe2f26..16bc3f6c18d0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -57,6 +57,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables val hasAccumulators = 
accumulables.size > 0 val hasInput = stageData.inputBytes > 0 + val hasOutput = stageData.outputBytes > 0 val hasShuffleRead = stageData.shuffleReadBytes > 0 val hasShuffleWrite = stageData.shuffleWriteBytes > 0 val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0 @@ -74,6 +75,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {Utils.bytesToString(stageData.inputBytes)} }} + {if (hasOutput) { +
  • + Output: + {Utils.bytesToString(stageData.outputBytes)} +
  • + }} {if (hasShuffleRead) {
  • Shuffle read: @@ -162,6 +169,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ {if (hasInput) Seq(("Input", "")) else Nil} ++ + {if (hasOutput) Seq(("Output", "")) else Nil} ++ {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++ {if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++ {if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) @@ -172,7 +180,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val taskTable = UIUtils.listingTable( unzipped._1, - taskRow(hasAccumulators, hasInput, hasShuffleRead, hasShuffleWrite, hasBytesSpilled), + taskRow(hasAccumulators, hasInput, hasOutput, hasShuffleRead, hasShuffleWrite, + hasBytesSpilled), tasks, headerClasses = unzipped._2) // Excludes tasks which failed and have incomplete metrics @@ -260,6 +269,11 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } val inputQuantiles = Input +: getFormattedSizeQuantiles(inputSizes) + val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble + } + val outputQuantiles = Output +: getFormattedSizeQuantiles(outputSizes) + val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } @@ -296,6 +310,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { , {gettingResultQuantiles}, if (hasInput) {inputQuantiles} else Nil, + if (hasOutput) {outputQuantiles} else Nil, if (hasShuffleRead) {shuffleReadQuantiles} else Nil, if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, @@ -328,6 +343,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { def taskRow( hasAccumulators: Boolean, hasInput: Boolean, + hasOutput: Boolean, hasShuffleRead: Boolean, hasShuffleWrite: Boolean, hasBytesSpilled: Boolean)(taskData: TaskUIData): Seq[Node] = { @@ -351,6 +367,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") .getOrElse("") + val maybeOutput = metrics.flatMap(_.outputMetrics) + val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("") + val outputReadable = maybeOutput + .map(m => s"${Utils.bytesToString(m.bytesWritten)}") + .getOrElse("") + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead) val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("") @@ -417,6 +439,11 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {inputReadable} }} + {if (hasOutput) { + + {outputReadable} + + }} {if (hasShuffleRead) { {shuffleReadReadable} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 3b4866e05956..eae542df85d0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -45,6 +45,7 @@ private[ui] class StageTableBase( Duration Tasks: Succeeded/Total Input + Output 
Shuffle Read + commons-io + commons-io + + + + + org.apache.hbase + hbase-hadoop-compat + ${hbase.version} + + + org.apache.hbase + hbase-hadoop-compat + ${hbase.version} + test-jar + test + com.twitter algebird-core_${scala.binary.version} From 19dcb5714ba326c272981e6e7e547ff7990648b9 Mon Sep 17 00:00:00 2001 From: Varadharajan Mukundan Date: Mon, 10 Nov 2014 14:32:29 -0800 Subject: [PATCH 073/652] [SPARK-4047] - Generate runtime warnings for example implementation of PageRank Based on SPARK-2434, this PR generates runtime warnings for example implementations (Python, Scala) of PageRank. Author: Varadharajan Mukundan Closes #2894 from varadharajan/SPARK-4047 and squashes the following commits: 5f9406b [Varadharajan Mukundan] [SPARK-4047] - Point users to LogisticRegressionWithSGD and LogisticRegressionWithLBFGS instead of LogisticRegressionModel 252f595 [Varadharajan Mukundan] a. Generate runtime warnings for 05a018b [Varadharajan Mukundan] Fix PageRank implementation's package reference 5c2bf54 [Varadharajan Mukundan] [SPARK-4047] - Generate runtime warnings for example implementation of PageRank (cherry picked from commit 974d334cf06a84317234a6c8e2e9ecca8271fa41) Signed-off-by: Xiangrui Meng --- .../org/apache/spark/examples/JavaHdfsLR.java | 15 +++++++++++++++ .../org/apache/spark/examples/JavaPageRank.java | 13 +++++++++++++ examples/src/main/python/pagerank.py | 8 ++++++++ .../org/apache/spark/examples/LocalFileLR.scala | 6 ++++-- .../org/apache/spark/examples/LocalLR.scala | 6 ++++-- .../org/apache/spark/examples/SparkHdfsLR.scala | 6 ++++-- .../org/apache/spark/examples/SparkLR.scala | 6 ++++-- .../apache/spark/examples/SparkPageRank.scala | 15 +++++++++++++++ .../spark/examples/SparkTachyonHdfsLR.scala | 16 ++++++++++++++++ 9 files changed, 83 insertions(+), 8 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java index 6c177de359b6..31a79ddd3fff 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java @@ -30,12 +30,25 @@ /** * Logistic regression based classification. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. 
*/ public final class JavaHdfsLR { private static final int D = 10; // Number of dimensions private static final Random rand = new Random(42); + static void showWarning() { + String warning = "WARN: This is a naive implementation of Logistic Regression " + + "and is given as an example!\n" + + "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " + + "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " + + "for more conventional use."; + System.err.println(warning); + } + static class DataPoint implements Serializable { DataPoint(double[] x, double y) { this.x = x; @@ -109,6 +122,8 @@ public static void main(String[] args) { System.exit(1); } + showWarning(); + SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD lines = sc.textFile(args[0]); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index c22506491fbf..a5db8accdf13 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -45,10 +45,21 @@ * URL neighbor URL * ... * where URL and their neighbors are separated by space(s). + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.graphx.lib.PageRank */ public final class JavaPageRank { private static final Pattern SPACES = Pattern.compile("\\s+"); + static void showWarning() { + String warning = "WARN: This is a naive implementation of PageRank " + + "and is given as an example! \n" + + "Please use the PageRank implementation found in " + + "org.apache.spark.graphx.lib.PageRank for more conventional use."; + System.err.println(warning); + } + private static class Sum implements Function2 { @Override public Double call(Double a, Double b) { @@ -62,6 +73,8 @@ public static void main(String[] args) throws Exception { System.exit(1); } + showWarning(); + SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index b539c4128cdc..a5f25d78c114 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +This is an example implementation of PageRank. For more conventional use, +Please refer to PageRank implementation provided by graphx +""" + import re import sys from operator import add @@ -40,6 +45,9 @@ def parseNeighbors(urls): print >> sys.stderr, "Usage: pagerank " exit(-1) + print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is + given as an example! Please refer to PageRank implementation provided by graphx""" + # Initialize the spark context. sc = SparkContext(appName="PythonPageRank") diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index 931faac5463c..ac2ea35bbd0e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector} * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. 
For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object LocalFileLR { val D = 10 // Numer of dimensions @@ -41,7 +42,8 @@ object LocalFileLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 2d75b9d2590f..92a683ad57ea 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector} * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object LocalLR { val N = 10000 // Number of data points @@ -48,7 +49,8 @@ object LocalLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 325851089437..9099c2fcc90b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -32,7 +32,8 @@ import org.apache.spark.scheduler.InputFormatInfo * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkHdfsLR { val D = 10 // Numer of dimensions @@ -54,7 +55,8 @@ object SparkHdfsLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. 
""".stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index fc23308fc4ad..257a7d29f922 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -30,7 +30,8 @@ import org.apache.spark._ * Usage: SparkLR [slices] * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkLR { val N = 10000 // Number of data points @@ -53,7 +54,8 @@ object SparkLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 4c7e006da061..8d092b6506d3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -28,13 +28,28 @@ import org.apache.spark.{SparkConf, SparkContext} * URL neighbor URL * ... * where URL and their neighbors are separated by space(s). + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.graphx.lib.PageRank */ object SparkPageRank { + + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of PageRank and is given as an example! + |Please use the PageRank implementation found in org.apache.spark.graphx.lib.PageRank + |for more conventional use. + """.stripMargin) + } + def main(args: Array[String]) { if (args.length < 1) { System.err.println("Usage: SparkPageRank ") System.exit(1) } + + showWarning() + val sparkConf = new SparkConf().setAppName("PageRank") val iters = if (args.length > 0) args(1).toInt else 10 val ctx = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 96d13612e46d..4393b99e636b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -32,11 +32,24 @@ import org.apache.spark.storage.StorageLevel /** * Logistic regression based classification. * This example uses Tachyon to persist rdds during computation. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. 
*/ object SparkTachyonHdfsLR { val D = 10 // Numer of dimensions val rand = new Random(42) + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of Logistic Regression and is given as an example! + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS + |for more conventional use. + """.stripMargin) + } + case class DataPoint(x: Vector[Double], y: Double) def parsePoint(line: String): DataPoint = { @@ -51,6 +64,9 @@ object SparkTachyonHdfsLR { } def main(args: Array[String]) { + + showWarning() + val inputPath = args(0) val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR") val conf = new Configuration() From 0089a4f64d90f923dc02aee45bcda4be726d740a Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 10 Nov 2014 15:55:15 -0800 Subject: [PATCH 074/652] [SPARK-4319][SQL] Enable an ignored test "null count". Author: Takuya UESHIN Closes #3185 from ueshin/issues/SPARK-4319 and squashes the following commits: a44a38e [Takuya UESHIN] Enable an ignored test "null count". (cherry picked from commit dbf10588de03e8ea993fff687a78727eff55db1f) Signed-off-by: Michael Armbrust --- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9 ++++----- .../src/test/scala/org/apache/spark/sql/TestData.scala | 9 +++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 702714af5308..8a80724c08c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -281,14 +281,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { 3) } - // No support for primitive nulls yet. - ignore("null count") { + test("null count") { checkAnswer( - sql("SELECT a, COUNT(b) FROM testData3"), - Seq((1,0), (2, 1))) + sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"), + Seq((1, 0), (2, 1))) checkAnswer( - testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)), + sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), (2, 1, 2, 2, 1) :: Nil) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index ef87a230639b..92b49e815590 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -64,11 +64,12 @@ object TestData { BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD binaryData.registerTempTable("binaryData") - // TODO: There is no way to express null primitives as case classes currently... 
+ case class TestData3(a: Int, b: Option[Int]) val testData3 = - logical.LocalRelation('a.int, 'b.int).loadData( - (1, null) :: - (2, 2) :: Nil) + TestSQLContext.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toSchemaRDD + testData3.registerTempTable("testData3") val emptyTableData = logical.LocalRelation('a.int, 'b.int) From 1ed1c68c0aa8f4da517cd4ac5c4ab117d2cee839 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 10 Nov 2014 17:20:52 -0800 Subject: [PATCH 075/652] [SQL] remove a decimal case branch that has no effect at runtime it generates warnings at compile time marmbrus Author: Xiangrui Meng Closes #3192 from mengxr/dtc-decimal and squashes the following commits: 955e9fb [Xiangrui Meng] remove a decimal case branch that has no effect (cherry picked from commit d793d80c8084923ea04dcf7d268eec8ede490127) Signed-off-by: Michael Armbrust --- .../org/apache/spark/sql/types/util/DataTypeConversions.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 3fa4a7c6481d..9aad7b3df4ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -133,7 +133,6 @@ protected[sql] object DataTypeConversions { def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d)) - case (d: java.math.BigDecimal, _) => BigDecimal(d) case (other, _) => other } From ff071e35173224546a879685c2febdd9ea0ab630 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 10 Nov 2014 17:22:57 -0800 Subject: [PATCH 076/652] [SPARK-4250] [SQL] Fix bug of constant null value mapping to ConstantObjectInspector Author: Cheng Hao Closes #3114 from chenghao-intel/constant_null_oi and squashes the following commits: e603bda [Cheng Hao] fix the bug of null value for primitive types 50a13ba [Cheng Hao] fix the timezone issue f54f369 [Cheng Hao] fix bug of constant null value for ObjectInspector (cherry picked from commit fa777833b52b6f339cdc335e8e3935cfe9a2a7eb) Signed-off-by: Michael Armbrust --- .../spark/sql/hive/HiveInspectors.scala | 78 ++++++++++-------- ...testing-0-9a02bc7de09bcabcbd4c91f54a814c20 | 1 + .../udf_if-0-b7ffa85b5785cccef2af1b285348cc2c | 1 + .../udf_if-1-30cf7f51f92b5684e556deff3032d49a | 1 + .../udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4 | 0 .../udf_if-3-20206f17367ff284d67044abd745ce9f | 1 + .../udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca | 0 .../udf_if-5-a7db13aec05c97792f9331d63709d8cc | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 52 +++++++++++- .../org/apache/spark/sql/hive/Shim12.scala | 70 ++++++++++------ .../org/apache/spark/sql/hive/Shim13.scala | 80 +++++++++++++------ 11 files changed, 199 insertions(+), 86 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 create mode 100644 sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c create mode 100644 sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a create mode 100644 sql/hive/src/test/resources/golden/udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4 create mode 100644 
sql/hive/src/test/resources/golden/udf_if-3-20206f17367ff284d67044abd745ce9f create mode 100644 sql/hive/src/test/resources/golden/udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca create mode 100644 sql/hive/src/test/resources/golden/udf_if-5-a7db13aec05c97792f9331d63709d8cc diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index bdc7e1dac192..7e76aff642bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -88,6 +88,7 @@ private[hive] trait HiveInspectors { * @return convert the data into catalyst type */ def unwrap(data: Any, oi: ObjectInspector): Any = oi match { + case _ if data == null => null case hvoi: HiveVarcharObjectInspector => if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue case hdoi: HiveDecimalObjectInspector => @@ -250,46 +251,53 @@ private[hive] trait HiveInspectors { } def toInspector(expr: Expression): ObjectInspector = expr match { - case Literal(value: String, StringType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Int, IntegerType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Double, DoubleType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Boolean, BooleanType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Long, LongType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Float, FloatType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Short, ShortType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Byte, ByteType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Array[Byte], BinaryType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: java.sql.Date, DateType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: java.sql.Timestamp, TimestampType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: BigDecimal, DecimalType()) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Decimal, DecimalType()) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value.toBigDecimal) + case Literal(value, StringType) => + HiveShim.getStringWritableConstantObjectInspector(value) + case Literal(value, IntegerType) => + HiveShim.getIntWritableConstantObjectInspector(value) + case Literal(value, DoubleType) => + HiveShim.getDoubleWritableConstantObjectInspector(value) + case Literal(value, BooleanType) => + HiveShim.getBooleanWritableConstantObjectInspector(value) + case Literal(value, LongType) => + HiveShim.getLongWritableConstantObjectInspector(value) + case Literal(value, FloatType) => + HiveShim.getFloatWritableConstantObjectInspector(value) + case Literal(value, ShortType) => + HiveShim.getShortWritableConstantObjectInspector(value) + case Literal(value, ByteType) => + HiveShim.getByteWritableConstantObjectInspector(value) + case Literal(value, BinaryType) => + HiveShim.getBinaryWritableConstantObjectInspector(value) + case Literal(value, DateType) => + HiveShim.getDateWritableConstantObjectInspector(value) + case Literal(value, TimestampType) => + 
HiveShim.getTimestampWritableConstantObjectInspector(value) + case Literal(value, DecimalType()) => + HiveShim.getDecimalWritableConstantObjectInspector(value) case Literal(_, NullType) => HiveShim.getPrimitiveNullWritableConstantObjectInspector - case Literal(value: Seq[_], ArrayType(dt, _)) => + case Literal(value, ArrayType(dt, _)) => val listObjectInspector = toInspector(dt) - val list = new java.util.ArrayList[Object]() - value.foreach(v => list.add(wrap(v, listObjectInspector))) - ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) - case Literal(map: Map[_, _], MapType(keyType, valueType, _)) => - val value = new java.util.HashMap[Object, Object]() + if (value == null) { + ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) + } else { + val list = new java.util.ArrayList[Object]() + value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector))) + ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) + } + case Literal(value, MapType(keyType, valueType, _)) => val keyOI = toInspector(keyType) val valueOI = toInspector(valueType) - map.foreach (entry => value.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI))) - ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, value) - case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") + if (value == null) { + ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, null) + } else { + val map = new java.util.HashMap[Object, Object]() + value.asInstanceOf[Map[_, _]].foreach (entry => { + map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI)) + }) + ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map) + } case _ => toInspector(expr.dataType) } diff --git a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 b/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 new file mode 100644 index 000000000000..7c41615f8c18 --- /dev/null +++ b/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 @@ -0,0 +1 @@ +1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL 1969-12-31 16:00:00.001 NULL 1 NULL diff --git a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c new file mode 100644 index 000000000000..2cf0d9d61882 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c @@ -0,0 +1 @@ +There is no documentation for function 'if' diff --git a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a new file mode 100644 index 000000000000..2cf0d9d61882 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a @@ -0,0 +1 @@ +There is no documentation for function 'if' diff --git a/sql/hive/src/test/resources/golden/udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4 b/sql/hive/src/test/resources/golden/udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/udf_if-3-20206f17367ff284d67044abd745ce9f b/sql/hive/src/test/resources/golden/udf_if-3-20206f17367ff284d67044abd745ce9f new file mode 100644 index 000000000000..a29e96cbd1db --- 
/dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-3-20206f17367ff284d67044abd745ce9f @@ -0,0 +1 @@ +1 1 1 1 NULL 2 diff --git a/sql/hive/src/test/resources/golden/udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca b/sql/hive/src/test/resources/golden/udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/udf_if-5-a7db13aec05c97792f9331d63709d8cc b/sql/hive/src/test/resources/golden/udf_if-5-a7db13aec05c97792f9331d63709d8cc new file mode 100644 index 000000000000..f0669b86989d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-5-a7db13aec05c97792f9331d63709d8cc @@ -0,0 +1 @@ +128 1.1 ABC 12.3 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index b897dff0159f..684d22807c0c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.util.{Locale, TimeZone} + +import org.scalatest.BeforeAndAfter import scala.util.Try @@ -28,14 +31,59 @@ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.{SQLConf, Row, SchemaRDD} case class TestData(a: Int, b: String) /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. */ -class HiveQuerySuite extends HiveComparisonTest { +class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { + private val originalTimeZone = TimeZone.getDefault + private val originalLocale = Locale.getDefault + + override def beforeAll() { + TestHive.cacheTables = true + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + } + + override def afterAll() { + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + } + + createQueryTest("constant null testing", + """SELECT + |IF(FALSE, CAST(NULL AS STRING), CAST(1 AS STRING)) AS COL1, + |IF(TRUE, CAST(NULL AS STRING), CAST(1 AS STRING)) AS COL2, + |IF(FALSE, CAST(NULL AS INT), CAST(1 AS INT)) AS COL3, + |IF(TRUE, CAST(NULL AS INT), CAST(1 AS INT)) AS COL4, + |IF(FALSE, CAST(NULL AS DOUBLE), CAST(1 AS DOUBLE)) AS COL5, + |IF(TRUE, CAST(NULL AS DOUBLE), CAST(1 AS DOUBLE)) AS COL6, + |IF(FALSE, CAST(NULL AS BOOLEAN), CAST(1 AS BOOLEAN)) AS COL7, + |IF(TRUE, CAST(NULL AS BOOLEAN), CAST(1 AS BOOLEAN)) AS COL8, + |IF(FALSE, CAST(NULL AS BIGINT), CAST(1 AS BIGINT)) AS COL9, + |IF(TRUE, CAST(NULL AS BIGINT), CAST(1 AS BIGINT)) AS COL10, + |IF(FALSE, CAST(NULL AS FLOAT), CAST(1 AS FLOAT)) AS COL11, + |IF(TRUE, CAST(NULL AS FLOAT), CAST(1 AS FLOAT)) AS COL12, + |IF(FALSE, CAST(NULL AS SMALLINT), CAST(1 AS SMALLINT)) AS COL13, + |IF(TRUE, CAST(NULL AS SMALLINT), CAST(1 AS SMALLINT)) AS COL14, + |IF(FALSE, CAST(NULL AS TINYINT), CAST(1 AS TINYINT)) AS COL15, + |IF(TRUE, CAST(NULL AS TINYINT), CAST(1 AS TINYINT)) AS COL16, + |IF(FALSE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL17, + |IF(TRUE, CAST(NULL AS BINARY), 
CAST("1" AS BINARY)) AS COL18, + |IF(FALSE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL19, + |IF(TRUE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL20, + |IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, + |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL22, + |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23, + |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL24 + |FROM src LIMIT 1""".stripMargin) + createQueryTest("constant array", """ |SELECT sort_array( diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 8e946b7e82f5..8ba25f889d17 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -57,54 +57,74 @@ private[hive] object HiveShim { new TableDesc(serdeClass, inputFormatClass, outputFormatClass, properties) } - def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector = + def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, new hadoopIo.Text(value)) + PrimitiveCategory.STRING, + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])) - def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector = + def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.INT, new hadoopIo.IntWritable(value)) + PrimitiveCategory.INT, + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])) - def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector = + def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DOUBLE, new hiveIo.DoubleWritable(value)) + PrimitiveCategory.DOUBLE, + if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double])) - def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector = + def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BOOLEAN, new hadoopIo.BooleanWritable(value)) + PrimitiveCategory.BOOLEAN, + if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])) - def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector = + def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.LONG, new hadoopIo.LongWritable(value)) + PrimitiveCategory.LONG, + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])) - def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector = + def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.FLOAT, new hadoopIo.FloatWritable(value)) + PrimitiveCategory.FLOAT, + if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float])) - def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector = 
+ def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.SHORT, new hiveIo.ShortWritable(value)) + PrimitiveCategory.SHORT, + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])) - def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector = + def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BYTE, new hiveIo.ByteWritable(value)) + PrimitiveCategory.BYTE, + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])) - def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector = + def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BINARY, new hadoopIo.BytesWritable(value)) + PrimitiveCategory.BINARY, + if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])) - def getPrimitiveWritableConstantObjectInspector(value: java.sql.Date): ObjectInspector = + def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DATE, new hiveIo.DateWritable(value)) + PrimitiveCategory.DATE, + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])) - def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector = + def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.TIMESTAMP, new hiveIo.TimestampWritable(value)) - - def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector = + PrimitiveCategory.TIMESTAMP, + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + }) + + def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.DECIMAL, - new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying()))) + if (value == null) { + null + } else { + new hiveIo.HiveDecimalWritable( + HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) + }) def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 0bc330cdbecb..e4aee57f0ad9 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -56,54 +56,86 @@ private[hive] object HiveShim { new TableDesc(inputFormatClass, outputFormatClass, properties) } - def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector = + def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.stringTypeInfo, new hadoopIo.Text(value)) + TypeInfoFactory.stringTypeInfo, + if (value == null) null else new 
hadoopIo.Text(value.asInstanceOf[String])) - def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector = + def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.intTypeInfo, new hadoopIo.IntWritable(value)) + TypeInfoFactory.intTypeInfo, + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])) - def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector = + def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.doubleTypeInfo, new hiveIo.DoubleWritable(value)) + TypeInfoFactory.doubleTypeInfo, if (value == null) { + null + } else { + new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + }) - def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector = + def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.booleanTypeInfo, new hadoopIo.BooleanWritable(value)) + TypeInfoFactory.booleanTypeInfo, if (value == null) { + null + } else { + new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + }) - def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector = + def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.longTypeInfo, new hadoopIo.LongWritable(value)) + TypeInfoFactory.longTypeInfo, + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])) - def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector = + def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.floatTypeInfo, new hadoopIo.FloatWritable(value)) + TypeInfoFactory.floatTypeInfo, if (value == null) { + null + } else { + new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + }) - def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector = + def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.shortTypeInfo, new hiveIo.ShortWritable(value)) + TypeInfoFactory.shortTypeInfo, + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])) - def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector = + def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.byteTypeInfo, new hiveIo.ByteWritable(value)) + TypeInfoFactory.byteTypeInfo, + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])) - def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector = + def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.binaryTypeInfo, new hadoopIo.BytesWritable(value)) + TypeInfoFactory.binaryTypeInfo, if (value == null) { + null + } else { + new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + }) - def getPrimitiveWritableConstantObjectInspector(value: 
java.sql.Date): ObjectInspector = + def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.dateTypeInfo, new hiveIo.DateWritable(value)) + TypeInfoFactory.dateTypeInfo, + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])) - def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector = + def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.timestampTypeInfo, new hiveIo.TimestampWritable(value)) + TypeInfoFactory.timestampTypeInfo, if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + }) - def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector = + def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( TypeInfoFactory.decimalTypeInfo, - new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying()))) + if (value == null) { + null + } else { + // TODO precise, scale? + new hiveIo.HiveDecimalWritable( + HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) + }) def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( From f0eb0a79cc68c0f254ddf1a1bba672321c84d341 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 10 Nov 2014 17:26:03 -0800 Subject: [PATCH 077/652] [SPARK-4149][SQL] ISO 8601 support for json date time strings This implement the feature davies mentioned in https://github.com/apache/spark/pull/2901#discussion-diff-19313312 Author: Daoyuan Wang Closes #3012 from adrian-wang/iso8601 and squashes the following commits: 50df6e7 [Daoyuan Wang] json data timestamp ISO8601 support (cherry picked from commit a1fc059b69c9ed150bf8a284404cc149ddaa27d6) Signed-off-by: Michael Armbrust --- .../org/apache/spark/sql/json/JsonRDD.scala | 5 ++-- .../sql/types/util/DataTypeConversions.scala | 30 +++++++++++++++++++ .../org/apache/spark/sql/json/JsonSuite.scala | 7 +++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 0f2dcdcacf0c..d9d7a3fea396 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.json import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.types.util.DataTypeConversions import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} @@ -378,7 +379,7 @@ private[sql] object JsonRDD extends Logging { private def toDate(value: Any): Date = { value match { // only support string as date - case value: java.lang.String => Date.valueOf(value) + case value: java.lang.String => new Date(DataTypeConversions.stringToTime(value).getTime) } } @@ -386,7 +387,7 @@ private[sql] object JsonRDD extends Logging { value match { case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => Timestamp.valueOf(value) + case value: java.lang.String 
=> toTimestamp(DataTypeConversions.stringToTime(value).getTime) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 9aad7b3df4ee..d4258156f18f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types.util +import java.text.SimpleDateFormat + import scala.collection.JavaConverters._ import org.apache.spark.sql._ @@ -129,6 +131,34 @@ protected[sql] object DataTypeConversions { StructType(structType.getFields.map(asScalaStructField)) } + def stringToTime(s: String): java.util.Date = { + if (!s.contains('T')) { + // JDBC escape string + if (s.contains(' ')) { + java.sql.Timestamp.valueOf(s) + } else { + java.sql.Date.valueOf(s) + } + } else if (s.endsWith("Z")) { + // this is zero timezone of ISO8601 + stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") + } else if (s.indexOf("GMT") == -1) { + // timezone with ISO8601 + val inset = "+00.00".length + val s0 = s.substring(0, s.length - inset) + val s1 = s.substring(s.length - inset, s.length) + if (s0.substring(s0.lastIndexOf(':')).contains('.')) { + stringToTime(s0 + "GMT" + s1) + } else { + stringToTime(s0 + ".0GMT" + s1) + } + } else { + // ISO8601 with GMT insert + val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) + ISO8601GMT.parse(s) + } + } + /** Converts Java objects to catalyst rows / types */ def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index cade244f7ac3..f8ca2c773d9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -66,6 +66,13 @@ class JsonSuite extends QueryTest { val strDate = "2014-10-15" checkTypePromotion(Date.valueOf(strDate), enforceCorrectType(strDate, DateType)) + + val ISO8601Time1 = "1970-01-01T01:00:01.0Z" + checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType)) + checkTypePromotion(new Date(3601000), enforceCorrectType(ISO8601Time1, DateType)) + val ISO8601Time2 = "1970-01-01T02:00:01-01:00" + checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType)) + checkTypePromotion(new Date(10801000), enforceCorrectType(ISO8601Time2, DateType)) } test("Get compatible type") { From 07ba50f7eff3db68f120d979a5f0ca37cb2a886e Mon Sep 17 00:00:00 2001 From: surq Date: Mon, 10 Nov 2014 17:37:16 -0800 Subject: [PATCH 078/652] [SPARK-3954][Streaming] Optimization to FileInputDStream about convert files to RDDS there are 3 loops with files sequence in spark source. loops files sequence: 1.files.map(...) 
2.files.zip(fileRDDs) 3.files-size.foreach It's will very time consuming when lots of files.So I do the following correction: 3 loops with files sequence => only one loop Author: surq Closes #2811 from surq/SPARK-3954 and squashes the following commits: 321bbe8 [surq] updated the code style.The style from [for...yield]to [files.map(file=>{})] 88a2c20 [surq] Merge branch 'master' of https://github.com/apache/spark into SPARK-3954 178066f [surq] modify code's style. [Exceeds 100 columns] 626ef97 [surq] remove redundant import(ArrayBuffer) 739341f [surq] promote the speed of convert files to RDDS (cherry picked from commit ce6ed2abd14de26b9ceaa415e9a42fbb1338f5fa) Signed-off-by: Tathagata Das --- .../apache/spark/streaming/dstream/FileInputDStream.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 8152b7542ac5..55d6cf6a783e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -120,14 +120,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas /** Generate one RDD from an array of files */ private def filesToRDD(files: Seq[String]): RDD[(K, V)] = { - val fileRDDs = files.map(file => context.sparkContext.newAPIHadoopFile[K, V, F](file)) - files.zip(fileRDDs).foreach { case (file, rdd) => { + val fileRDDs = files.map(file =>{ + val rdd = context.sparkContext.newAPIHadoopFile[K, V, F](file) if (rdd.partitions.size == 0) { logError("File " + file + " has no data in it. Spark Streaming can only ingest " + "files that have been \"moved\" to the directory assigned to the file stream. 
" + "Refer to the streaming programming guide for more details.") } - }} + rdd + }) new UnionRDD(context.sparkContext, fileRDDs) } From 50c02d68a7fbc9e91c01fea4997846f46f7ea910 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 10 Nov 2014 17:46:05 -0800 Subject: [PATCH 079/652] [SPARK-4274] [SQL] Fix NPE in printing the details of the query plan Author: Cheng Hao Closes #3139 from chenghao-intel/comparison_test and squashes the following commits: f5d7146 [Cheng Hao] avoid exception in printing the codegen enabled (cherry picked from commit c764d0ac1c6410ca2dd2558cb6bcbe8ad5f02481) Signed-off-by: Michael Armbrust --- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 84eaf401f240..31cc4170aa86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -444,7 +444,7 @@ class SQLContext(@transient val sparkContext: SparkContext) |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} - |Code Generation: ${executedPlan.codegenEnabled} + |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} |== RDD == """.stripMargin.trim } From e725cab66441a5de4f32630c865d0fcb25f8aed2 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 10 Nov 2014 19:31:52 -0800 Subject: [PATCH 080/652] [SPARK-3649] Remove GraphX custom serializers As [reported][1] on the mailing list, GraphX throws ``` java.lang.ClassCastException: java.lang.Long cannot be cast to scala.Tuple2 at org.apache.spark.graphx.impl.RoutingTableMessageSerializer$$anon$1$$anon$2.writeObject(Serializers.scala:39) at org.apache.spark.storage.DiskBlockObjectWriter.write(BlockObjectWriter.scala:195) at org.apache.spark.util.collection.ExternalSorter.spillToMergeableFile(ExternalSorter.scala:329) ``` when sort-based shuffle attempts to spill to disk. This is because GraphX defines custom serializers for shuffling pair RDDs that assume Spark will always serialize the entire pair object rather than breaking it up into its components. However, the spill code path in sort-based shuffle [violates this assumption][2]. GraphX uses the custom serializers to compress vertex ID keys using variable-length integer encoding. However, since the serializer can no longer rely on the key and value being serialized and deserialized together, performing such encoding would either require writing a tag byte (costly) or maintaining state in the serializer and assuming that serialization calls will alternate between key and value (fragile). Instead, this PR simply removes the custom serializers. This causes a **10% slowdown** (494 s to 543 s) and **16% increase in per-iteration communication** (2176 MB to 2518 MB) for PageRank (averages across 3 trials, 10 iterations per trial, uk-2007-05 graph, 16 r3.2xlarge nodes). 
[1]: http://apache-spark-user-list.1001560.n3.nabble.com/java-lang-ClassCastException-java-lang-Long-cannot-be-cast-to-scala-Tuple2-td13926.html#a14501 [2]: https://github.com/apache/spark/blob/f9d6220c792b779be385f3022d146911a22c2130/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala#L329 Author: Ankur Dave Closes #2503 from ankurdave/SPARK-3649 and squashes the following commits: a49c2ad [Ankur Dave] [SPARK-3649] Remove GraphX custom serializers (cherry picked from commit 300887bd76c5018bfe396c5d47443be251368359) Signed-off-by: Reynold Xin --- .../org/apache/spark/graphx/VertexRDD.scala | 14 +- .../graphx/impl/MessageToPartition.scala | 50 --- .../graphx/impl/RoutingTablePartition.scala | 18 - .../spark/graphx/impl/Serializers.scala | 369 ------------------ .../apache/spark/graphx/SerializerSuite.scala | 122 ------ 5 files changed, 6 insertions(+), 567 deletions(-) delete mode 100644 graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala delete mode 100644 graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala delete mode 100644 graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 2c8b245955d1..12216d9d33d6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -27,8 +27,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx.impl.RoutingTablePartition import org.apache.spark.graphx.impl.ShippableVertexPartition import org.apache.spark.graphx.impl.VertexAttributeBlock -import org.apache.spark.graphx.impl.RoutingTableMessageRDDFunctions._ -import org.apache.spark.graphx.impl.VertexRDDFunctions._ /** * Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by @@ -233,7 +231,7 @@ class VertexRDD[@specialized VD: ClassTag]( case _ => this.withPartitionsRDD[VD3]( partitionsRDD.zipPartitions( - other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) { + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { (partIter, msgs) => partIter.map(_.leftJoin(msgs)(f)) } ) @@ -277,7 +275,7 @@ class VertexRDD[@specialized VD: ClassTag]( case _ => this.withPartitionsRDD( partitionsRDD.zipPartitions( - other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) { + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { (partIter, msgs) => partIter.map(_.innerJoin(msgs)(f)) } ) @@ -297,7 +295,7 @@ class VertexRDD[@specialized VD: ClassTag]( */ def aggregateUsingIndex[VD2: ClassTag]( messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = { - val shuffled = messages.copartitionWithVertices(this.partitioner.get) + val shuffled = messages.partitionBy(this.partitioner.get) val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) => thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc)) } @@ -371,7 +369,7 @@ object VertexRDD { def apply[VD: ClassTag](vertices: RDD[(VertexId, VD)]): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices - case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size)) + case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size)) } val vertexPartitions = 
vPartitioned.mapPartitions( iter => Iterator(ShippableVertexPartition(iter)), @@ -412,7 +410,7 @@ object VertexRDD { ): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices - case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size)) + case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size)) } val routingTables = createRoutingTables(edges, vPartitioned.partitioner.get) val vertexPartitions = vPartitioned.zipPartitions(routingTables, preservesPartitioning = true) { @@ -454,7 +452,7 @@ object VertexRDD { .setName("VertexRDD.createRoutingTables - vid2pid (aggregation)") val numEdgePartitions = edges.partitions.size - vid2pid.copartitionWithVertices(vertexPartitioner).mapPartitions( + vid2pid.partitionBy(vertexPartitioner).mapPartitions( iter => Iterator(RoutingTablePartition.fromMsgs(numEdgePartitions, iter)), preservesPartitioning = true) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala deleted file mode 100644 index 714f3b81c9da..000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.impl - -import scala.language.implicitConversions -import scala.reflect.{classTag, ClassTag} - -import org.apache.spark.Partitioner -import org.apache.spark.graphx.{PartitionID, VertexId} -import org.apache.spark.rdd.{ShuffledRDD, RDD} - - -private[graphx] -class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { - def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { - val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner) - - // Set a custom serializer if the data is of int or double type. 
- if (classTag[VD] == ClassTag.Int) { - rdd.setSerializer(new IntAggMsgSerializer) - } else if (classTag[VD] == ClassTag.Long) { - rdd.setSerializer(new LongAggMsgSerializer) - } else if (classTag[VD] == ClassTag.Double) { - rdd.setSerializer(new DoubleAggMsgSerializer) - } - rdd - } -} - -private[graphx] -object VertexRDDFunctions { - implicit def rdd2VertexRDDFunctions[VD: ClassTag](rdd: RDD[(VertexId, VD)]) = { - new VertexRDDFunctions(rdd) - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index b27485953f71..7a7fa91aadfe 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -29,24 +29,6 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage -private[graphx] -class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { - /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */ - def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, Int, Int]( - self, partitioner).setSerializer(new RoutingTableMessageSerializer) - } -} - -private[graphx] -object RoutingTableMessageRDDFunctions { - import scala.language.implicitConversions - - implicit def rdd2RoutingTableMessageRDDFunctions(rdd: RDD[RoutingTableMessage]) = { - new RoutingTableMessageRDDFunctions(rdd) - } -} - private[graphx] object RoutingTablePartition { /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala deleted file mode 100644 index 3909efcdfc99..000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ /dev/null @@ -1,369 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.graphx.impl - -import scala.language.existentials - -import java.io.{EOFException, InputStream, OutputStream} -import java.nio.ByteBuffer - -import scala.reflect.ClassTag - -import org.apache.spark.serializer._ - -import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage - -private[graphx] -class RoutingTableMessageSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream): SerializationStream = - new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T): SerializationStream = { - val msg = t.asInstanceOf[RoutingTableMessage] - writeVarLong(msg._1, optimizePositive = false) - writeInt(msg._2) - this - } - } - - override def deserializeStream(s: InputStream): DeserializationStream = - new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readInt() - (a, b).asInstanceOf[T] - } - } - } -} - -private[graphx] -class VertexIdMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, _)] - writeVarLong(msg._1, optimizePositive = false) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - (readVarLong(optimizePositive = false), null).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Int]. */ -private[graphx] -class IntAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Int)] - writeVarLong(msg._1, optimizePositive = false) - writeUnsignedVarInt(msg._2) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readUnsignedVarInt() - (a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Long]. */ -private[graphx] -class LongAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Long)] - writeVarLong(msg._1, optimizePositive = false) - writeVarLong(msg._2, optimizePositive = true) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readVarLong(optimizePositive = true) - (a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Double]. 
*/ -private[graphx] -class DoubleAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Double)] - writeVarLong(msg._1, optimizePositive = false) - writeDouble(msg._2) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readDouble() - (a, b).asInstanceOf[T] - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Helper classes to shorten the implementation of those special serializers. -//////////////////////////////////////////////////////////////////////////////// - -private[graphx] -abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream { - // The implementation should override this one. - def writeObject[T: ClassTag](t: T): SerializationStream - - def writeInt(v: Int) { - s.write(v >> 24) - s.write(v >> 16) - s.write(v >> 8) - s.write(v) - } - - def writeUnsignedVarInt(value: Int) { - if ((value >>> 7) == 0) { - s.write(value.toInt) - } else if ((value >>> 14) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7) - } else if ((value >>> 21) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14) - } else if ((value >>> 28) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14 | 0x80) - s.write(value >>> 21) - } else { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14 | 0x80) - s.write(value >>> 21 | 0x80) - s.write(value >>> 28) - } - } - - def writeVarLong(value: Long, optimizePositive: Boolean) { - val v = if (!optimizePositive) (value << 1) ^ (value >> 63) else value - if ((v >>> 7) == 0) { - s.write(v.toInt) - } else if ((v >>> 14) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7).toInt) - } else if ((v >>> 21) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14).toInt) - } else if ((v >>> 28) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21).toInt) - } else if ((v >>> 35) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28).toInt) - } else if ((v >>> 42) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35).toInt) - } else if ((v >>> 49) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42).toInt) - } else if ((v >>> 56) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42 | 0x80).toInt) - s.write((v >>> 49).toInt) - } else { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 
| 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42 | 0x80).toInt) - s.write((v >>> 49 | 0x80).toInt) - s.write((v >>> 56).toInt) - } - } - - def writeLong(v: Long) { - s.write((v >>> 56).toInt) - s.write((v >>> 48).toInt) - s.write((v >>> 40).toInt) - s.write((v >>> 32).toInt) - s.write((v >>> 24).toInt) - s.write((v >>> 16).toInt) - s.write((v >>> 8).toInt) - s.write(v.toInt) - } - - def writeDouble(v: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(v)) - - override def flush(): Unit = s.flush() - - override def close(): Unit = s.close() -} - -private[graphx] -abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream { - // The implementation should override this one. - def readObject[T: ClassTag](): T - - def readInt(): Int = { - val first = s.read() - if (first < 0) throw new EOFException - (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF) - } - - def readUnsignedVarInt(): Int = { - var value: Int = 0 - var i: Int = 0 - def readOrThrow(): Int = { - val in = s.read() - if (in < 0) throw new EOFException - in & 0xFF - } - var b: Int = readOrThrow() - while ((b & 0x80) != 0) { - value |= (b & 0x7F) << i - i += 7 - if (i > 35) throw new IllegalArgumentException("Variable length quantity is too long") - b = readOrThrow() - } - value | (b << i) - } - - def readVarLong(optimizePositive: Boolean): Long = { - def readOrThrow(): Int = { - val in = s.read() - if (in < 0) throw new EOFException - in & 0xFF - } - var b = readOrThrow() - var ret: Long = b & 0x7F - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 7 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 14 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 21 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 28 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 35 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 42 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 49 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= b.toLong << 56 - } - } - } - } - } - } - } - } - if (!optimizePositive) (ret >>> 1) ^ -(ret & 1) else ret - } - - def readLong(): Long = { - val first = s.read() - if (first < 0) throw new EOFException() - (first.toLong << 56) | - (s.read() & 0xFF).toLong << 48 | - (s.read() & 0xFF).toLong << 40 | - (s.read() & 0xFF).toLong << 32 | - (s.read() & 0xFF).toLong << 24 | - (s.read() & 0xFF) << 16 | - (s.read() & 0xFF) << 8 | - (s.read() & 0xFF) - } - - def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong()) - - override def close(): Unit = s.close() -} - -private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance { - - override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = - throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = - throw new UnsupportedOperationException - - // The implementation should override the following two. 
- override def serializeStream(s: OutputStream): SerializationStream - override def deserializeStream(s: InputStream): DeserializationStream -} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala deleted file mode 100644 index 864cb1fdf002..000000000000 --- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx - -import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream} - -import scala.util.Random -import scala.reflect.ClassTag - -import org.scalatest.FunSuite - -import org.apache.spark._ -import org.apache.spark.graphx.impl._ -import org.apache.spark.serializer.SerializationStream - - -class SerializerSuite extends FunSuite with LocalSparkContext { - - test("IntAggMsgSerializer") { - val outMsg = (4: VertexId, 5) - val bout = new ByteArrayOutputStream - val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Int) = inStrm.readObject() - val inMsg2: (VertexId, Int) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("LongAggMsgSerializer") { - val outMsg = (4: VertexId, 1L << 32) - val bout = new ByteArrayOutputStream - val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Long) = inStrm.readObject() - val inMsg2: (VertexId, Long) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("DoubleAggMsgSerializer") { - val outMsg = (4: VertexId, 5.0) - val bout = new ByteArrayOutputStream - val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Double) = inStrm.readObject() - val inMsg2: (VertexId, Double) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("variable long 
encoding") { - def testVarLongEncoding(v: Long, optimizePositive: Boolean) { - val bout = new ByteArrayOutputStream - val stream = new ShuffleSerializationStream(bout) { - def writeObject[T: ClassTag](t: T): SerializationStream = { - writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive) - this - } - } - stream.writeObject(v) - - val bin = new ByteArrayInputStream(bout.toByteArray) - val dstream = new ShuffleDeserializationStream(bin) { - def readObject[T: ClassTag](): T = { - readVarLong(optimizePositive).asInstanceOf[T] - } - } - val read = dstream.readObject[Long]() - assert(read === v) - } - - // Test all variable encoding code path (each branch uses 7 bits, i.e. 1L << 7 difference) - val d = Random.nextLong() % 128 - Seq[Long](0, 1L << 0 + d, 1L << 7 + d, 1L << 14 + d, 1L << 21 + d, 1L << 28 + d, 1L << 35 + d, - 1L << 42 + d, 1L << 49 + d, 1L << 56 + d, 1L << 63 + d).foreach { number => - testVarLongEncoding(number, optimizePositive = false) - testVarLongEncoding(number, optimizePositive = true) - testVarLongEncoding(-number, optimizePositive = false) - testVarLongEncoding(-number, optimizePositive = true) - } - } -} From 4eeaf3395a885b0a9ef79c31b720969155b0b7af Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 10 Nov 2014 22:18:00 -0800 Subject: [PATCH 081/652] [SPARK-4330][Doc] Link to proper URL for YARN overview In running-on-yarn.md, a link to YARN overview is here. But the URL is to YARN alpha's. It should be stable's. Author: Kousuke Saruta Closes #3196 from sarutak/SPARK-4330 and squashes the following commits: 30baa21 [Kousuke Saruta] Fixed running-on-yarn.md to point proper URL for YARN (cherry picked from commit 3c07b8f08240bafcdff5d174989fb433f4bc80b6) Signed-off-by: Matei Zaharia --- docs/running-on-yarn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 695813a2ba88..2f7e4981e5bb 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -4,7 +4,7 @@ title: Running Spark on YARN --- Support for running on [YARN (Hadoop -NextGen)](http://hadoop.apache.org/docs/r2.0.2-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html) +NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. # Preparations From df8242c9b6307c085d4c1a7ec446b1701a7e7cde Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 10 Nov 2014 22:26:16 -0800 Subject: [PATCH 082/652] [SPARK-4324] [PySpark] [MLlib] support numpy.array for all MLlib API This PR check all of the existing Python MLlib API to make sure that numpy.array is supported as Vector (also RDD of numpy.array). It also improve some docstring and doctest. 
cc mateiz mengxr Author: Davies Liu Closes #3189 from davies/numpy and squashes the following commits: d5057c4 [Davies Liu] fix tests 6987611 [Davies Liu] support numpy.array for all MLlib API (cherry picked from commit 65083e93ddd552b7d3e4eb09f87c091ef2ae83a2) Signed-off-by: Xiangrui Meng --- python/pyspark/mllib/classification.py | 13 +++++--- python/pyspark/mllib/feature.py | 31 ++++++++++++++---- python/pyspark/mllib/random.py | 45 ++++++++++++++++++++++++-- python/pyspark/mllib/recommendation.py | 6 ++-- python/pyspark/mllib/regression.py | 15 ++++++--- python/pyspark/mllib/stat.py | 16 ++++++++- python/pyspark/mllib/util.py | 11 ++----- 7 files changed, 105 insertions(+), 32 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 297a2bf37d2c..5d90dddb5df1 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -62,6 +62,7 @@ class LogisticRegressionModel(LinearModel): """ def predict(self, x): + x = _convert_to_vector(x) margin = self.weights.dot(x) + self._intercept if margin > 0: prob = 1 / (1 + exp(-margin)) @@ -79,7 +80,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a logistic regression model on the given data. - :param data: The training data. + :param data: The training data, an RDD of LabeledPoint. :param iterations: The number of iterations (default: 100). :param step: The step parameter used in SGD (default: 1.0). @@ -136,6 +137,7 @@ class SVMModel(LinearModel): """ def predict(self, x): + x = _convert_to_vector(x) margin = self.weights.dot(x) + self.intercept return 1 if margin >= 0 else 0 @@ -148,7 +150,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, """ Train a support vector machine on the given data. - :param data: The training data. + :param data: The training data, an RDD of LabeledPoint. :param iterations: The number of iterations (default: 100). :param step: The step parameter used in SGD (default: 1.0). @@ -233,11 +235,12 @@ def train(cls, data, lambda_=1.0): classification. By making every vector a 0-1 vector, it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). - :param data: RDD of NumPy vectors, one per element, where the first - coordinate is the label and the rest is the feature vector - (e.g. a count vector). + :param data: RDD of LabeledPoint. :param lambda_: The smoothing parameter """ + first = data.first() + if not isinstance(first, LabeledPoint): + raise ValueError("`data` should be an RDD of LabeledPoint") labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 44bf6f269d7a..9ec28079aef4 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -25,7 +25,7 @@ from pyspark import RDD, SparkContext from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors +from pyspark.mllib.linalg import Vectors, _convert_to_vector __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] @@ -81,12 +81,16 @@ def transform(self, vector): """ Applies unit length normalization on a vector. - :param vector: vector to be normalized. + :param vector: vector or RDD of vector to be normalized. :return: normalized vector. If the norm of the input is zero, it will return the input vector. 
""" sc = SparkContext._active_spark_context assert sc is not None, "SparkContext should be initialized first" + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + else: + vector = _convert_to_vector(vector) return callMLlibFunc("normalizeVector", self.p, vector) @@ -95,8 +99,12 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): Wrapper for the model in JVM """ - def transform(self, dataset): - return self.call("transform", dataset) + def transform(self, vector): + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + else: + vector = _convert_to_vector(vector) + return self.call("transform", vector) class StandardScalerModel(JavaVectorTransformer): @@ -109,7 +117,7 @@ def transform(self, vector): """ Applies standardization transformation on a vector. - :param vector: Vector to be standardized. + :param vector: Vector or RDD of Vector to be standardized. :return: Standardized vector. If the variance of a column is zero, it will return default `0.0` for the column with zero variance. """ @@ -154,6 +162,7 @@ def fit(self, dataset): the transformation model. :return: a StandardScalarModel """ + dataset = dataset.map(_convert_to_vector) jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset) return StandardScalerModel(jmodel) @@ -211,6 +220,8 @@ def transform(self, dataset): :param dataset: an RDD of term frequency vectors :return: an RDD of TF-IDF vectors """ + if not isinstance(dataset, RDD): + raise TypeError("dataset should be an RDD of term frequency vectors") return JavaVectorTransformer.transform(self, dataset) @@ -255,7 +266,9 @@ def fit(self, dataset): :param dataset: an RDD of term frequency vectors """ - jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset) + if not isinstance(dataset, RDD): + raise TypeError("dataset should be an RDD of term frequency vectors") + jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset.map(_convert_to_vector)) return IDFModel(jmodel) @@ -287,6 +300,8 @@ def findSynonyms(self, word, num): Note: local use only """ + if not isinstance(word, basestring): + word = _convert_to_vector(word) words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) @@ -374,9 +389,11 @@ def fit(self, data): """ Computes the vector representation of each word in vocabulary. - :param data: training data. RDD of subtype of Iterable[String] + :param data: training data. RDD of list of string :return: Word2VecModel instance """ + if not isinstance(data, RDD): + raise TypeError("data should be an RDD of list of string") jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), int(self.numIterations), long(self.seed)) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 7eebfc6bcd89..cb4304f92152 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -52,6 +52,12 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): C{RandomRDDs.uniformRDD(sc, n, p, seed)\ .map(lambda v: a + (b - a) * v)} + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`. 
+ >>> x = RandomRDDs.uniformRDD(sc, 100).collect() >>> len(x) 100 @@ -76,6 +82,12 @@ def normalRDD(sc, size, numPartitions=None, seed=None): C{RandomRDDs.normal(sc, n, p, seed)\ .map(lambda v: mean + sigma * v)} + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) >>> stats = x.stats() >>> stats.count() @@ -93,6 +105,13 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ Pois(mean). + >>> mean = 100.0 >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L) >>> stats = x.stats() @@ -104,7 +123,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): >>> abs(stats.stdev() - sqrt(mean)) < 0.5 True """ - return callMLlibFunc("poissonRDD", sc._jsc, mean, size, numPartitions, seed) + return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod @toArray @@ -113,6 +132,13 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): Generates an RDD comprised of vectors containing i.i.d. samples drawn from the uniform distribution U(0.0, 1.0). + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD. + :param seed: Seed for the RNG that generates the seed for the generator in each partition. + :return: RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. + >>> import numpy as np >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) >>> mat.shape @@ -131,6 +157,13 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): Generates an RDD comprised of vectors containing i.i.d. samples drawn from the standard normal distribution. + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + >>> import numpy as np >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) >>> mat.shape @@ -149,6 +182,14 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): Generates an RDD comprised of vectors containing i.i.d. samples drawn from the Poisson distribution with the input mean. + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. 
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean). + >>> import numpy as np >>> mean = 100.0 >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) @@ -161,7 +202,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): >>> abs(mat.std() - sqrt(mean)) < 0.5 True """ - return callMLlibFunc("poissonVectorRDD", sc._jsc, mean, numRows, numCols, + return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols, numPartitions, seed) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index e26b152e0cdf..41bbd9a779c7 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -32,7 +32,7 @@ def __reduce__(self): return Rating, (self.user, self.product, self.rating) def __repr__(self): - return "Rating(%d, %d, %d)" % (self.user, self.product, self.rating) + return "Rating(%d, %d, %s)" % (self.user, self.product, self.rating) class MatrixFactorizationModel(JavaModelWrapper): @@ -51,7 +51,7 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> testset = sc.parallelize([(1, 2), (1, 1)]) >>> model = ALS.train(ratings, 1, seed=10) >>> model.predictAll(testset).collect() - [Rating(1, 1, 1), Rating(1, 2, 1)] + [Rating(1, 1, 1.0471...), Rating(1, 2, 1.9679...)] >>> model = ALS.train(ratings, 4, seed=10) >>> model.userFeatures().collect() @@ -79,7 +79,7 @@ class MatrixFactorizationModel(JavaModelWrapper): 0.4473... """ def predict(self, user, product): - return self._java_model.predict(user, product) + return self._java_model.predict(int(user), int(product)) def predictAll(self, user_product): assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 43c1a2fc101d..66e25a48dfa7 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -36,7 +36,7 @@ class LabeledPoint(object): """ def __init__(self, label, features): - self.label = label + self.label = float(label) self.features = _convert_to_vector(features) def __reduce__(self): @@ -46,7 +46,7 @@ def __str__(self): return "(" + ",".join((str(self.label), str(self.features))) + ")" def __repr__(self): - return "LabeledPoint(" + ",".join((repr(self.label), repr(self.features))) + ")" + return "LabeledPoint(%s, %s)" % (self.label, self.features) class LinearModel(object): @@ -55,7 +55,7 @@ class LinearModel(object): def __init__(self, weights, intercept): self._coeff = _convert_to_vector(weights) - self._intercept = intercept + self._intercept = float(intercept) @property def weights(self): @@ -66,7 +66,7 @@ def intercept(self): return self._intercept def __repr__(self): - return "(weights=%s, intercept=%s)" % (self._coeff, self._intercept) + return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept) class LinearRegressionModelBase(LinearModel): @@ -85,6 +85,7 @@ def predict(self, x): Predict the value of the dependent variable given a vector x containing values for the independent variables. """ + x = _convert_to_vector(x) return self.weights.dot(x) + self.intercept @@ -124,6 +125,9 @@ class LinearRegressionModel(LinearRegressionModelBase): # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. 
def _regression_train_wrapper(train_func, modelClass, data, initial_weights): + first = data.first() + if not isinstance(first, LabeledPoint): + raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) initial_weights = initial_weights or [0.0] * len(data.first().features) weights, intercept = train_func(_to_java_object_rdd(data, cache=True), _convert_to_vector(initial_weights)) @@ -264,7 +268,8 @@ def train(rdd, i): def _test(): import doctest from pyspark import SparkContext - globs = globals().copy() + import pyspark.mllib.regression + globs = pyspark.mllib.regression.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 0700f8a8e5a8..1980f5b03f43 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,6 +22,7 @@ from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Matrix, _convert_to_vector +from pyspark.mllib.regression import LabeledPoint __all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] @@ -107,6 +108,11 @@ def colStats(rdd): """ Computes column-wise summary statistics for the input RDD[Vector]. + :param rdd: an RDD[Vector] for which column-wise summary statistics + are to be computed. + :return: :class:`MultivariateStatisticalSummary` object containing + column-wise summary statistics. + >>> from pyspark.mllib.linalg import Vectors >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), ... Vectors.dense([4, 5, 0, 3]), @@ -140,6 +146,13 @@ def corr(x, y=None, method=None): to specify the method to be used for single RDD inout. If two RDDs of floats are passed in, a single float is returned. + :param x: an RDD of vector for which the correlation matrix is to be computed, + or an RDD of float of the same cardinality as y when y is specified. + :param y: an RDD of float of the same cardinality as x. + :param method: String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman` + :return: Correlation matrix comparing columns in x. + >>> x = sc.parallelize([1.0, 0.0, -2.0], 2) >>> y = sc.parallelize([4.0, 5.0, 3.0], 2) >>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2) @@ -242,7 +255,6 @@ def chiSqTest(observed, expected=None): >>> print round(chi.statistic, 4) 21.9958 - >>> from pyspark.mllib.regression import LabeledPoint >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), ... LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), ... 
LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), @@ -257,6 +269,8 @@ def chiSqTest(observed, expected=None): 1.5 """ if isinstance(observed, RDD): + if not isinstance(observed.first(), LabeledPoint): + raise ValueError("observed should be an RDD of LabeledPoint") jmodels = callMLlibFunc("chiSqTest", observed) return [ChiSqTestResult(m) for m in jmodels] diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 96aef8f510fa..4ed978b45409 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -161,15 +161,8 @@ def loadLabeledPoints(sc, path, minPartitions=None): >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name) - >>> loaded = MLUtils.loadLabeledPoints(sc, tempFile.name).collect() - >>> type(loaded[0]) == LabeledPoint - True - >>> print examples[0] - (1.1,(3,[0,2],[-1.23,4.56e-07])) - >>> type(examples[1]) == LabeledPoint - True - >>> print examples[1] - (0.0,[1.01,2.02,3.03]) + >>> MLUtils.loadLabeledPoints(sc, tempFile.name).collect() + [LabeledPoint(1.1, (3,[0,2],[-1.23,4.56e-07])), LabeledPoint(0.0, [1.01,2.02,3.03])] """ minPartitions = minPartitions or min(sc.defaultParallelism, 2) return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) From e9d009dc348bc06198ed2c9e03f1ba870401e6df Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Nov 2014 00:25:31 -0800 Subject: [PATCH 083/652] [SPARK-4307] Initialize FileDescriptor lazily in FileRegion. Netty's DefaultFileRegion requires a FileDescriptor in its constructor, which means we need to have a opened file handle. In super large workloads, this could lead to too many open files due to the way these file descriptors are cleaned. This pull request creates a new LazyFileRegion that initializes the FileDescriptor when we are sending data for the first time. Author: Reynold Xin Author: Reynold Xin Closes #3172 from rxin/lazyFD and squashes the following commits: 0bdcdc6 [Reynold Xin] Added reference to Netty's DefaultFileRegion d4564ae [Reynold Xin] Added SparkConf to the ctor argument of IndexShuffleBlockManager. 6ed369e [Reynold Xin] Code review feedback. 04cddc8 [Reynold Xin] [SPARK-4307] Initialize FileDescriptor lazily in FileRegion. 
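Reduced to a sketch, the change keeps only the File and defers opening a channel until data is first transferred. The Scala below is illustrative only; the actual class is the Java LazyFileRegion added by this patch (it implements Netty's FileRegion), and the name LazyTransfer and the simplified signatures here are assumptions for the example.

```
import java.io.File
import java.nio.channels.{FileChannel, WritableByteChannel}
import java.nio.file.StandardOpenOption

// Illustrative sketch only: hold the File, not an open descriptor.
class LazyTransfer(file: File, offset: Long, count: Long) {
  private var channel: FileChannel = null  // deliberately not opened in the constructor

  def transferTo(target: WritableByteChannel, position: Long): Long = {
    if (channel == null) {
      // Open the file descriptor only once bytes are actually being sent.
      channel = FileChannel.open(file.toPath, StandardOpenOption.READ)
    }
    channel.transferTo(offset + position, count - position, target)
  }

  def deallocate(): Unit = {
    if (channel != null) channel.close()
  }
}
```

With the eager DefaultFileRegion approach, the constructor pins one open file descriptor per pending transfer; deferring the open roughly bounds descriptor usage by the transfers actually in flight rather than by everything queued.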
(cherry picked from commit ef29a9a9aa85468869eb67ca67b66c65f508d0ee) Signed-off-by: Aaron Davidson --- .../StandaloneWorkerShuffleService.scala | 2 +- .../shuffle/FileShuffleBlockManager.scala | 8 +- .../shuffle/IndexShuffleBlockManager.scala | 8 +- .../shuffle/sort/SortShuffleManager.scala | 2 +- .../spark/ExternalShuffleServiceSuite.scala | 2 +- .../buffer/FileSegmentManagedBuffer.java | 23 ++-- .../spark/network/buffer/LazyFileRegion.java | 111 ++++++++++++++++++ .../spark/network/util/TransportConf.java | 17 +++ .../network/ChunkFetchIntegrationSuite.java | 9 +- .../shuffle/ExternalShuffleBlockHandler.java | 5 +- .../shuffle/ExternalShuffleBlockManager.java | 13 +- .../ExternalShuffleBlockManagerSuite.java | 10 +- .../shuffle/ExternalShuffleCleanupSuite.java | 13 +- .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/ExternalShuffleSecuritySuite.java | 2 +- .../network/yarn/YarnShuffleService.java | 4 +- 16 files changed, 191 insertions(+), 40 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala index 88118e283774..d044e1d01d42 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala @@ -40,7 +40,7 @@ class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: Secu private val useSasl: Boolean = securityManager.isAuthenticationEnabled() private val transportConf = SparkTransportConf.fromSparkConf(sparkConf) - private val blockHandler = new ExternalShuffleBlockHandler() + private val blockHandler = new ExternalShuffleBlockHandler(transportConf) private val transportContext: TransportContext = { val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler new TransportContext(transportConf, handler) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index f03e8e4bf1b7..7de2f9cbb286 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -27,6 +27,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup import org.apache.spark.storage._ @@ -68,6 +69,8 @@ private[spark] class FileShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager with Logging { + private val transportConf = SparkTransportConf.fromSparkConf(conf) + private lazy val blockManager = SparkEnv.get.blockManager // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. 
@@ -182,13 +185,14 @@ class FileShuffleBlockManager(conf: SparkConf) val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) if (segmentOpt.isDefined) { val segment = segmentOpt.get - return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length) + return new FileSegmentManagedBuffer( + transportConf, segment.file, segment.offset, segment.length) } } throw new IllegalStateException("Failed to find shuffle block: " + blockId) } else { val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(file, 0, file.length) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index a48f0c9eceb5..b292587d3702 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -22,8 +22,9 @@ import java.nio.ByteBuffer import com.google.common.io.ByteStreams -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.storage._ /** @@ -38,10 +39,12 @@ import org.apache.spark.storage._ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). private[spark] -class IndexShuffleBlockManager extends ShuffleBlockManager { +class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { private lazy val blockManager = SparkEnv.get.blockManager + private val transportConf = SparkTransportConf.fromSparkConf(conf) + /** * Mapping to a single shuffleBlockId with reduce ID 0. 
* */ @@ -109,6 +112,7 @@ class IndexShuffleBlockManager extends ShuffleBlockManager { val offset = in.readLong() val nextOffset = in.readLong() new FileSegmentManagedBuffer( + transportConf, getDataFile(blockId.shuffleId, blockId.mapId), offset, nextOffset - offset) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index b727438ae7e4..bda30a56d808 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -25,7 +25,7 @@ import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { - private val indexShuffleBlockManager = new IndexShuffleBlockManager() + private val indexShuffleBlockManager = new IndexShuffleBlockManager(conf) private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() /** diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 6608ed1e57b3..9623d665177e 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -39,7 +39,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { override def beforeAll() { val transportConf = SparkTransportConf.fromSparkConf(conf) - rpcHandler = new ExternalShuffleBlockHandler() + rpcHandler = new ExternalShuffleBlockHandler(transportConf) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 5fa1527ddff9..844eff4f4c70 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -31,24 +31,19 @@ import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.network.util.TransportConf; /** * A {@link ManagedBuffer} backed by a segment in a file. */ public final class FileSegmentManagedBuffer extends ManagedBuffer { - - /** - * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). - * Avoid unless there's a good reason not to. - */ - // TODO: Make this configurable - private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; - + private final TransportConf conf; private final File file; private final long offset; private final long length; - public FileSegmentManagedBuffer(File file, long offset, long length) { + public FileSegmentManagedBuffer(TransportConf conf, File file, long offset, long length) { + this.conf = conf; this.file = file; this.offset = offset; this.length = length; @@ -65,7 +60,7 @@ public ByteBuffer nioByteBuffer() throws IOException { try { channel = new RandomAccessFile(file, "r").getChannel(); // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. 
- if (length < MIN_MEMORY_MAP_BYTES) { + if (length < conf.memoryMapBytes()) { ByteBuffer buf = ByteBuffer.allocate((int) length); channel.position(offset); while (buf.remaining() != 0) { @@ -134,8 +129,12 @@ public ManagedBuffer release() { @Override public Object convertToNetty() throws IOException { - FileChannel fileChannel = new FileInputStream(file).getChannel(); - return new DefaultFileRegion(fileChannel, offset, length); + if (conf.lazyFileDescriptor()) { + return new LazyFileRegion(file, offset, length); + } else { + FileChannel fileChannel = new FileInputStream(file).getChannel(); + return new DefaultFileRegion(fileChannel, offset, length); + } } public File getFile() { return file; } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java new file mode 100644 index 000000000000..81bc8ec40fc8 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.FileInputStream; +import java.io.File; +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; + +import com.google.common.base.Objects; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; + +import org.apache.spark.network.util.JavaUtils; + +/** + * A FileRegion implementation that only creates the file descriptor when the region is being + * transferred. This cannot be used with Epoll because there is no native support for it. + * + * This is mostly copied from DefaultFileRegion implementation in Netty. In the future, we + * should push this into Netty so the native Epoll transport can support this feature. + */ +public final class LazyFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final File file; + private final long position; + private final long count; + + private FileChannel channel; + + private long numBytesTransferred = 0L; + + /** + * @param file file to transfer. + * @param position start position for the transfer. + * @param count number of bytes to transfer starting from position. 
+ */ + public LazyFileRegion(File file, long position, long count) { + this.file = file; + this.position = position; + this.count = count; + } + + @Override + protected void deallocate() { + JavaUtils.closeQuietly(channel); + } + + @Override + public long position() { + return position; + } + + @Override + public long transfered() { + return numBytesTransferred; + } + + @Override + public long count() { + return count; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + if (channel == null) { + channel = new FileInputStream(file).getChannel(); + } + + long count = this.count - position; + if (count < 0 || position < 0) { + throw new IllegalArgumentException( + "position out of range: " + position + " (expected: 0 - " + (count - 1) + ')'); + } + + if (count == 0) { + return 0L; + } + + long written = channel.transferTo(this.position + position, count, target); + if (written > 0) { + numBytesTransferred += written; + } + return written; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", file) + .add("position", position) + .add("count", count) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 787a8f0031af..621427d8cba5 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -75,4 +75,21 @@ public int connectionTimeoutMs() { * Only relevant if maxIORetries > 0. */ public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); } + + /** + * Minimum size of a block that we should start using memory map rather than reading in through + * normal IO operations. This prevents Spark from memory mapping very small blocks. In general, + * memory mapping has high overhead for blocks close to or below the page size of the OS. + */ + public int memoryMapBytes() { + return conf.getInt("spark.storage.memoryMapThreshold", 2 * 1024 * 1024); + } + + /** + * Whether to initialize shuffle FileDescriptor lazily or not. If true, file descriptors are + * created only when data is going to be transferred. This can reduce the number of open files. 
+ */ + public boolean lazyFileDescriptor() { + return conf.getBoolean("spark.shuffle.io.lazyFD", true); + } } diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index c4158833976a..dfb7740344ed 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -63,6 +63,8 @@ public class ChunkFetchIntegrationSuite { static ManagedBuffer bufferChunk; static ManagedBuffer fileChunk; + private TransportConf transportConf; + @BeforeClass public static void setUp() throws Exception { int bufSize = 100000; @@ -80,9 +82,10 @@ public static void setUp() throws Exception { new Random().nextBytes(fileContent); fp.write(fileContent); fp.close(); - fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); + streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { @@ -90,7 +93,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { if (chunkIndex == BUFFER_CHUNK_INDEX) { return new NioManagedBuffer(buf); } else if (chunkIndex == FILE_CHUNK_INDEX) { - return new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); } else { throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index a6db4b2abd6c..46ca9708621b 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; +import org.apache.spark.network.util.TransportConf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,8 +49,8 @@ public class ExternalShuffleBlockHandler extends RpcHandler { private final ExternalShuffleBlockManager blockManager; private final OneForOneStreamManager streamManager; - public ExternalShuffleBlockHandler() { - this(new OneForOneStreamManager(), new ExternalShuffleBlockManager()); + public ExternalShuffleBlockHandler(TransportConf conf) { + this(new OneForOneStreamManager(), new ExternalShuffleBlockManager(conf)); } /** Enables mocking out the StreamManager and BlockManager. 
*/ diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java index ffb7faa3dbdc..dfe0ba059509 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -37,6 +37,7 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; /** * Manages converting shuffle BlockIds into physical segments of local files, from a process outside @@ -56,14 +57,17 @@ public class ExternalShuffleBlockManager { // Single-threaded Java executor used to perform expensive recursive directory deletion. private final Executor directoryCleaner; - public ExternalShuffleBlockManager() { + private final TransportConf conf; + + public ExternalShuffleBlockManager(TransportConf conf) { // TODO: Give this thread a name. - this(Executors.newSingleThreadExecutor()); + this(conf, Executors.newSingleThreadExecutor()); } // Allows tests to have more control over when directories are cleaned up. @VisibleForTesting - ExternalShuffleBlockManager(Executor directoryCleaner) { + ExternalShuffleBlockManager(TransportConf conf, Executor directoryCleaner) { + this.conf = conf; this.executors = Maps.newConcurrentMap(); this.directoryCleaner = directoryCleaner; } @@ -167,7 +171,7 @@ private void deleteExecutorDirs(String[] dirs) { // TODO: Support consolidated hash shuffle files private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); - return new FileSegmentManagedBuffer(shuffleFile, 0, shuffleFile.length()); + return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); } /** @@ -187,6 +191,7 @@ private ManagedBuffer getSortBasedShuffleBlockData( long offset = in.readLong(); long nextOffset = in.readLong(); return new FileSegmentManagedBuffer( + conf, getFile(executor.localDirs, executor.subDirsPerLocalDir, "shuffle_" + shuffleId + "_" + mapId + "_0.data"), offset, diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java index da54797e8923..dad6428a836f 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java @@ -22,6 +22,8 @@ import java.io.InputStreamReader; import com.google.common.io.CharStreams; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -37,6 +39,8 @@ public class ExternalShuffleBlockManagerSuite { static TestShuffleDataContext dataContext; + static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + @BeforeClass public static void beforeAll() throws IOException { dataContext = new TestShuffleDataContext(2, 5); @@ -56,7 +60,7 @@ public static void afterAll() { @Test public void testBadRequests() { - 
ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); // Unregistered executor try { manager.getBlockData("app0", "exec1", "shuffle_1_1_0"); @@ -87,7 +91,7 @@ public void testBadRequests() { @Test public void testSortShuffleBlocks() throws IOException { - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); manager.registerExecutor("app0", "exec0", dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); @@ -106,7 +110,7 @@ public void testSortShuffleBlocks() throws IOException { @Test public void testHashShuffleBlocks() throws IOException { - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); manager.registerExecutor("app0", "exec0", dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index c8ece3bc53ac..254e3a7a32b9 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -25,20 +25,23 @@ import com.google.common.util.concurrent.MoreExecutors; import org.junit.Test; - import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. 
Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); @Test public void noCleanupAndCleanup() throws IOException { TestShuffleDataContext dataContext = createSomeData(); - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); manager.applicationRemoved("app", false /* cleanup */); @@ -61,7 +64,7 @@ public void cleanupUsesExecutor() throws IOException { @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } }; - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(noThreadExecutor); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, noThreadExecutor); manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); manager.applicationRemoved("app", true); @@ -78,7 +81,7 @@ public void cleanupMultipleExecutors() throws IOException { TestShuffleDataContext dataContext0 = createSomeData(); TestShuffleDataContext dataContext1 = createSomeData(); - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); manager.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); manager.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); @@ -93,7 +96,7 @@ public void cleanupOnlyRemovedApp() throws IOException { TestShuffleDataContext dataContext0 = createSomeData(); TestShuffleDataContext dataContext1 = createSomeData(); - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); manager.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); manager.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 687bde59fdae..02c10bcb7b26 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -92,7 +92,7 @@ public static void beforeAll() throws IOException { dataContext1.insertHashShuffleData(1, 0, exec1Blocks); conf = new TransportConf(new SystemPropertyConfigProvider()); - handler = new ExternalShuffleBlockHandler(); + handler = new ExternalShuffleBlockHandler(conf); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 8afceab1d585..759a12910c94 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -42,7 +42,7 @@ public class 
ExternalShuffleSecuritySuite { @Before public void beforeEach() { - RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(), + RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(conf), new TestSecretKeyHolder("my-app-id", "secret")); TransportContext context = new TransportContext(conf, handler); this.server = context.createServer(); diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index bb0b8f7e6cba..a34aabe9e78a 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -95,10 +95,11 @@ private boolean isAuthenticationEnabled() { */ @Override protected void serviceInit(Configuration conf) { + TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); - RpcHandler rpcHandler = new ExternalShuffleBlockHandler(); + RpcHandler rpcHandler = new ExternalShuffleBlockHandler(transportConf); if (authEnabled) { secretManager = new ShuffleSecretManager(); rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); @@ -106,7 +107,6 @@ protected void serviceInit(Configuration conf) { int port = conf.getInt( SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); - TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); TransportContext transportContext = new TransportContext(transportConf, rpcHandler); shuffleServer = transportContext.createServer(port); String authEnabledString = authEnabled ? "enabled" : "not enabled"; From fe8a1cd292ff067aabf78dd009204a4500d0cf75 Mon Sep 17 00:00:00 2001 From: maji2014 Date: Tue, 11 Nov 2014 02:18:27 -0800 Subject: [PATCH 084/652] [SPARK-4295][External]Fix exception in SparkSinkSuite Handle exception in SparkSinkSuite, please refer to [SPARK-4295] Author: maji2014 Closes #3177 from maji2014/spark-4295 and squashes the following commits: 312620a [maji2014] change a new statement for spark-4295 24c3d21 [maji2014] add log4j.properties for SparkSinkSuite and spark-4295 c807bf6 [maji2014] Fix exception in SparkSinkSuite (cherry picked from commit f8811a5695af2dfe156f07431288db7b8cd97159) Signed-off-by: Tathagata Das --- .../src/test/resources/log4j.properties | 29 +++++++++++++++++++ .../streaming/flume/sink/SparkSinkSuite.scala | 1 + 2 files changed, 30 insertions(+) create mode 100644 external/flume-sink/src/test/resources/log4j.properties diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties new file mode 100644 index 000000000000..4411d6e20c52 --- /dev/null +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file streaming/target/unit-tests.log +log4j.rootCategory=INFO, file +# log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN + diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index a2b2cc6149d9..650b2fbe1c14 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -159,6 +159,7 @@ class SparkSinkSuite extends FunSuite { channelContext.put("transactionCapacity", 1000.toString) channelContext.put("keep-alive", 0.toString) channelContext.putAll(overrides) + channel.setName(scala.util.Random.nextString(10)) channel.configure(channelContext) val sink = new SparkSink() From 7710b7156e0c82445783c3709a4a793d820627b2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 11 Nov 2014 02:22:23 -0800 Subject: [PATCH 085/652] [SPARK-2492][Streaming] kafkaReceiver minor changes to align with Kafka 0.8 Update the KafkaReceiver's behavior when auto.offset.reset is set. In Kafka 0.8, `auto.offset.reset` is a hint telling the consumer to seek to the beginning or end of the partition only when the requested offset is out of range. The previous code, however, treated `auto.offset.reset` as an order to seek to the beginning or end immediately, which differs from the behavior defined by Kafka 0.8. Also, deleting existing ZK metadata in the Receiver when multiple consumers are launched introduces the issue described in [SPARK-2383](https://issues.apache.org/jira/browse/SPARK-2383). So here we change the code to offer the user an API to explicitly reset offsets before creating the Kafka stream, while keeping the same behavior as Kafka 0.8 for the parameter `auto.offset.reset`. @tdas, would you please review this PR? Thanks a lot.
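For illustration only (not part of this patch; the ZooKeeper address, group id, and topic name below are made up): after this change, `auto.offset.reset` is passed straight through to the Kafka 0.8 consumer as an ordinary property, so it only takes effect when the requested offset is out of range, e.g.

~~~
import kafka.serializer.StringDecoder

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

val ssc = new StreamingContext(
  new SparkConf().setAppName("kafka-offset-example").setMaster("local[2]"), Seconds(2))

val kafkaParams = Map(
  "zookeeper.connect" -> "localhost:2181",   // assumed ZooKeeper address
  "group.id"          -> "example-group",    // assumed consumer group
  "auto.offset.reset" -> "smallest")         // hint only: applied when the offset is out of range

// Two type parameters for the key/value types, two for their Kafka decoders.
val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
  ssc, kafkaParams, Map("example-topic" -> 1), StorageLevel.MEMORY_AND_DISK_SER_2)
~~~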
Author: jerryshao Closes #1420 from jerryshao/kafka-fix and squashes the following commits: d6ae94d [jerryshao] Address the comment to remove the resetOffset() function de3a4c8 [jerryshao] Fix compile error 4a1c3f9 [jerryshao] Doc changes b2c1430 [jerryshao] Move offset reset to a helper function to let user explicitly delete ZK metadata by calling this API fac8fd6 [jerryshao] Changes to align with Kafka 0.8 (cherry picked from commit c8850a3d6d948f9dd9ee026ee350428968d3c21b) Signed-off-by: Tathagata Das --- .../streaming/kafka/KafkaInputDStream.scala | 30 ------------------- .../spark/streaming/kafka/KafkaUtils.scala | 11 ++++--- 2 files changed, 5 insertions(+), 36 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index e20e2c8f2699..28ac5929df44 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -26,8 +26,6 @@ import java.util.concurrent.Executors import kafka.consumer._ import kafka.serializer.Decoder import kafka.utils.VerifiableProperties -import kafka.utils.ZKStringSerializer -import org.I0Itec.zkclient._ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel @@ -97,12 +95,6 @@ class KafkaReceiver[ consumerConnector = Consumer.create(consumerConfig) logInfo("Connected to " + zkConnect) - // When auto.offset.reset is defined, it is our responsibility to try and whack the - // consumer group zk node. - if (kafkaParams.contains("auto.offset.reset")) { - tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id")) - } - val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) .asInstanceOf[Decoder[K]] @@ -139,26 +131,4 @@ class KafkaReceiver[ } } } - - // It is our responsibility to delete the consumer group when specifying auto.offset.reset. This - // is because Kafka 0.7.2 only honors this param when the group is not in zookeeper. - // - // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied - // from Kafka's ConsoleConsumer. 
See code related to 'auto.offset.reset' when it is set to - // 'smallest'/'largest': - // scalastyle:off - // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala - // scalastyle:on - private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { - val dir = "/consumers/" + groupId - logInfo("Cleaning up temporary Zookeeper data under " + dir + ".") - val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) - try { - zk.deleteRecursive(dir) - } catch { - case e: Throwable => logWarning("Error cleaning up temporary Zookeeper data", e) - } finally { - zk.close() - } - } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 48668f763e41..ec812e1ef3b0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -17,19 +17,18 @@ package org.apache.spark.streaming.kafka -import scala.reflect.ClassTag -import scala.collection.JavaConversions._ - import java.lang.{Integer => JInt} import java.util.{Map => JMap} +import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + import kafka.serializer.{Decoder, StringDecoder} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext, JavaPairDStream} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} - +import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream object KafkaUtils { /** From cc1f3a0d6bfc5299e9db1d8ca50e33d2411d7cd9 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Tue, 11 Nov 2014 03:02:12 -0800 Subject: [PATCH 086/652] [Streaming][Minor]Replace some 'if-else' in Clock Replace some 'if-else' statement by math.min and math.max in Clock.scala Author: huangzhaowei Closes #3088 from SaintBacchus/StreamingClock and squashes the following commits: 7b7f8e7 [huangzhaowei] [Streaming][Minor]Replace some 'if-else' in Clock (cherry picked from commit 6e03de304e0294017d832763fd71e642736f8c33) Signed-off-by: Tathagata Das --- .../org/apache/spark/streaming/util/Clock.scala | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala index 39145a3ab081..7cd867ce34b8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala @@ -41,13 +41,7 @@ class SystemClock() extends Clock { return currentTime } - val pollTime = { - if (waitTime / 10.0 > minPollTime) { - (waitTime / 10.0).toLong - } else { - minPollTime - } - } + val pollTime = math.max(waitTime / 10.0, minPollTime).toLong while (true) { currentTime = System.currentTimeMillis() @@ -55,12 +49,7 @@ class SystemClock() extends Clock { if (waitTime <= 0) { return currentTime } - val sleepTime = - if (waitTime < pollTime) { - waitTime - } else { - pollTime - } + val sleepTime = math.min(waitTime, pollTime) Thread.sleep(sleepTime) } -1 From 8f7e80f30bd34897963334d0245c0ea6fccd6182 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 11 Nov 2014 12:30:35 -0600 
Subject: [PATCH 087/652] SPARK-4305 [BUILD] yarn-alpha profile won't build due to network/yarn module SPARK-3797 introduced the `network/yarn` module, but its YARN code depends on YARN APIs not present in older versions covered by the `yarn-alpha` profile. As a result builds like `mvn -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.7 -DskipTests clean package` fail. The solution is just to not build `network/yarn` with profile `yarn-alpha`. Author: Sean Owen Closes #3167 from srowen/SPARK-4305 and squashes the following commits: 88938cb [Sean Owen] Don't build network/yarn in yarn-alpha profile as it won't compile (cherry picked from commit f820b563d88f6a972c219d9340fe95110493fb87) Signed-off-by: Thomas Graves --- pom.xml | 1 - 1 file changed, 1 deletion(-) diff --git a/pom.xml b/pom.xml index 88ef67c515b3..4e0cd6c151d0 100644 --- a/pom.xml +++ b/pom.xml @@ -1229,7 +1229,6 @@ yarn-alpha yarn - network/yarn From ec0d89bc93f3a69a844d4b133bf185ee24048726 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 11 Nov 2014 12:33:53 -0600 Subject: [PATCH 088/652] [SPARK-4282][YARN] Stopping flag in YarnClientSchedulerBackend should be volatile In YarnClientSchedulerBackend, a variable "stopping" is used as a flag and it's accessed by some threads so it should be volatile. Author: Kousuke Saruta Closes #3143 from sarutak/stopping-flag-volatile and squashes the following commits: 58fdcc9 [Kousuke Saruta] Marked stoppig flag as volatile (cherry picked from commit 7f3718842cc4025bb2ee2f5a3ec12efd100f6589) Signed-off-by: Thomas Graves --- .../spark/scheduler/cluster/YarnClientSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index f6f6dc52433e..2923e6729cd6 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -33,7 +33,7 @@ private[spark] class YarnClientSchedulerBackend( private var client: Client = null private var appId: ApplicationId = null - private var stopping: Boolean = false + @volatile private var stopping: Boolean = false /** * Create a Yarn client to submit an application to the ResourceManager. 
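As a minimal standalone sketch (not Spark code) of why this matters: under the JVM memory model, a plain `var` written by one thread is not guaranteed to become visible to another thread that keeps reading it, so a stop flag like the one above needs `@volatile`.

~~~
class Worker {
  // Without @volatile, the reading thread may cache `stopping` and loop forever
  // even after stop() has been called from another thread.
  @volatile private var stopping = false

  def stop(): Unit = { stopping = true }   // typically called from a different thread

  def run(): Unit = {
    while (!stopping) {
      Thread.sleep(100)                    // placeholder for real work
    }
  }
}
~~~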
From 6a7ddf4ce10e540ecc389235a7e4d994e225b9e6 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 11 Nov 2014 14:29:18 -0800 Subject: [PATCH 089/652] SPARK-2269 Refactor mesos scheduler resourceOffers and add unit test Author: Timothy Chen Closes #1487 from tnachen/resource_offer_refactor and squashes the following commits: 4ea5dec [Timothy Chen] Rebase from master and address comments 9ccab09 [Timothy Chen] Address review comments e6494dc [Timothy Chen] Refactor class loading 8207428 [Timothy Chen] Refactor mesos scheduler resourceOffers and add unit test (cherry picked from commit a878660d2d7bb7ad9b5818a674e1e7c651077e78) Signed-off-by: Andrew Or --- .../cluster/mesos/MesosSchedulerBackend.scala | 137 ++++++++---------- .../mesos/MesosSchedulerBackendSuite.scala | 94 ++++++++++++ 2 files changed, 152 insertions(+), 79 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index c5f3493477bc..d13795186c48 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -166,29 +166,16 @@ private[spark] class MesosSchedulerBackend( execArgs } - private def setClassLoader(): ClassLoader = { - val oldClassLoader = Thread.currentThread.getContextClassLoader - Thread.currentThread.setContextClassLoader(classLoader) - oldClassLoader - } - - private def restoreClassLoader(oldClassLoader: ClassLoader) { - Thread.currentThread.setContextClassLoader(oldClassLoader) - } - override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) registeredLock.synchronized { isRegistered = true registeredLock.notifyAll() } - } finally { - restoreClassLoader(oldClassLoader) } } @@ -200,6 +187,16 @@ private[spark] class MesosSchedulerBackend( } } + private def inClassLoader()(fun: => Unit) = { + val oldClassLoader = Thread.currentThread.getContextClassLoader + Thread.currentThread.setContextClassLoader(classLoader) + try { + fun + } finally { + Thread.currentThread.setContextClassLoader(oldClassLoader) + } + } + override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} @@ -210,66 +207,57 @@ private[spark] class MesosSchedulerBackend( * tasks are balanced across the cluster. 
*/ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - val oldClassLoader = setClassLoader() - try { - synchronized { - // Build a big list of the offerable workers, and remember their indices so that we can - // figure out which Offer to reply to for each worker - val offerableWorkers = new ArrayBuffer[WorkerOffer] - val offerableIndices = new HashMap[String, Int] - - def sufficientOffer(o: Offer) = { - val mem = getResource(o.getResourcesList, "mem") - val cpus = getResource(o.getResourcesList, "cpus") - val slaveId = o.getSlaveId.getValue - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= 2 * scheduler.CPUS_PER_TASK) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) - } + inClassLoader() { + val (acceptedOffers, declinedOffers) = offers.partition { o => + val mem = getResource(o.getResourcesList, "mem") + val cpus = getResource(o.getResourcesList, "cpus") + val slaveId = o.getSlaveId.getValue + (mem >= MemoryUtils.calculateTotalMemory(sc) && + // need at least 1 for executor, 1 for task + cpus >= 2 * scheduler.CPUS_PER_TASK) || + (slaveIdsWithExecutors.contains(slaveId) && + cpus >= scheduler.CPUS_PER_TASK) + } - for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) { - val slaveId = offer.getSlaveId.getValue - offerableIndices.put(slaveId, index) - val cpus = if (slaveIdsWithExecutors.contains(slaveId)) { - getResource(offer.getResourcesList, "cpus").toInt - } else { - // If the executor doesn't exist yet, subtract CPU for executor - getResource(offer.getResourcesList, "cpus").toInt - - scheduler.CPUS_PER_TASK - } - offerableWorkers += new WorkerOffer( - offer.getSlaveId.getValue, - offer.getHostname, - cpus) + val offerableWorkers = acceptedOffers.map { o => + val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { + getResource(o.getResourcesList, "cpus").toInt + } else { + // If the executor doesn't exist yet, subtract CPU for executor + getResource(o.getResourcesList, "cpus").toInt - + scheduler.CPUS_PER_TASK } + new WorkerOffer( + o.getSlaveId.getValue, + o.getHostname, + cpus) + } - // Call into the TaskSchedulerImpl - val taskLists = scheduler.resourceOffers(offerableWorkers) - - // Build a list of Mesos tasks for each slave - val mesosTasks = offers.map(o => new JArrayList[MesosTaskInfo]()) - for ((taskList, index) <- taskLists.zipWithIndex) { - if (!taskList.isEmpty) { - for (taskDesc <- taskList) { - val slaveId = taskDesc.executorId - val offerNum = offerableIndices(slaveId) - slaveIdsWithExecutors += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId)) - } + val slaveIdToOffer = acceptedOffers.map(o => o.getSlaveId.getValue -> o).toMap + + val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] + + // Call into the TaskSchedulerImpl + scheduler.resourceOffers(offerableWorkers) + .filter(!_.isEmpty) + .foreach { offer => + offer.foreach { taskDesc => + val slaveId = taskDesc.executorId + slaveIdsWithExecutors += slaveId + taskIdToSlaveId(taskDesc.taskId) = slaveId + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(createMesosTask(taskDesc, slaveId)) } } - // Reply to the offers - val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? 
- for (i <- 0 until offers.size) { - d.launchTasks(Collections.singleton(offers(i).getId), mesosTasks(i), filters) - } + // Reply to the offers + val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? + + mesosTasks.foreach { case (slaveId, tasks) => + d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } - } finally { - restoreClassLoader(oldClassLoader) + + declinedOffers.foreach(o => d.declineOffer(o.getId)) } } @@ -308,8 +296,7 @@ private[spark] class MesosSchedulerBackend( } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { val tid = status.getTaskId.getValue.toLong val state = TaskState.fromMesos(status.getState) synchronized { @@ -322,18 +309,13 @@ private[spark] class MesosSchedulerBackend( } } scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer) - } finally { - restoreClassLoader(oldClassLoader) } } override def error(d: SchedulerDriver, message: String) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { logError("Mesos error: " + message) scheduler.error(message) - } finally { - restoreClassLoader(oldClassLoader) } } @@ -350,15 +332,12 @@ private[spark] class MesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { logInfo("Mesos slave lost: " + slaveId.getValue) synchronized { slaveIdsWithExecutors -= slaveId.getValue } scheduler.executorLost(slaveId.getValue, reason) - } finally { - restoreClassLoader(oldClassLoader) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala new file mode 100644 index 000000000000..bef8d3a58ba6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.mesos + +import org.scalatest.FunSuite +import org.apache.spark.{scheduler, SparkConf, SparkContext, LocalSparkContext} +import org.apache.spark.scheduler.{TaskDescription, WorkerOffer, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.mesos.{MemoryUtils, MesosSchedulerBackend} +import org.apache.mesos.SchedulerDriver +import org.apache.mesos.Protos._ +import org.scalatest.mock.EasyMockSugar +import org.apache.mesos.Protos.Value.Scalar +import org.easymock.{Capture, EasyMock} +import java.nio.ByteBuffer +import java.util.Collections +import java.util +import scala.collection.mutable + +class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with EasyMockSugar { + test("mesos resource offer is launching tasks") { + def createOffer(id: Int, mem: Int, cpu: Int) = { + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(mem)) + builder.addResourcesBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(cpu)) + builder.setId(OfferID.newBuilder().setValue(id.toString).build()).setFrameworkId(FrameworkID.newBuilder().setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue("s1")).setHostname("localhost").build() + } + + val driver = EasyMock.createMock(classOf[SchedulerDriver]) + val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) + + val sc = EasyMock.createMock(classOf[SparkContext]) + + EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() + EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes() + EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() + EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes() + EasyMock.replay(sc) + val minMem = MemoryUtils.calculateTotalMemory(sc).toInt + val minCpu = 4 + val offers = new java.util.ArrayList[Offer] + offers.add(createOffer(1, minMem, minCpu)) + offers.add(createOffer(1, minMem - 1, minCpu)) + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + val workerOffers = Seq(offers.get(0)).map(o => new WorkerOffer( + o.getSlaveId.getValue, + o.getHostname, + 2 + )) + val taskDesc = new TaskDescription(1L, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(workerOffers))).andReturn(Seq(Seq(taskDesc))) + EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() + EasyMock.replay(taskScheduler) + val capture = new Capture[util.Collection[TaskInfo]] + EasyMock.expect( + driver.launchTasks( + EasyMock.eq(Collections.singleton(offers.get(0).getId)), + EasyMock.capture(capture), + EasyMock.anyObject(classOf[Filters]) + ) + ).andReturn(Status.valueOf(1)) + EasyMock.expect(driver.declineOffer(offers.get(1).getId)).andReturn(Status.valueOf(1)) + EasyMock.replay(driver) + backend.resourceOffers(driver, offers) + assert(capture.getValue.size() == 1) + val taskInfo = capture.getValue.iterator().next() + assert(taskInfo.getName.equals("n1")) + val cpus = taskInfo.getResourcesList.get(0) + assert(cpus.getName.equals("cpus")) + assert(cpus.getScalar.getValue.equals(2.0)) + assert(taskInfo.getSlaveId.getValue.equals("s1")) + } +} From 307b69d73c37b5a580a1079843b13aeac1f6f6f4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Nov 2014 18:02:59 -0800 Subject: [PATCH 090/652] [Release] Log build output for each distribution --- dev/create-release/create-release.sh | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 281e8d4de6d7..50a9a2fa1cb9 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -27,6 +27,7 @@ # Would be nice to add: # - Send output to stderr and have useful logging in stdout +# Note: The following variables must be set before use! GIT_USERNAME=${GIT_USERNAME:-pwendell} GIT_PASSWORD=${GIT_PASSWORD:-XXX} GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} @@ -101,7 +102,7 @@ make_binary_release() { cp -r spark spark-$RELEASE_VERSION-bin-$NAME cd spark-$RELEASE_VERSION-bin-$NAME - ./make-distribution.sh --name $NAME --tgz $FLAGS + ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee binary-release-$NAME.log cd .. cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . rm -rf spark-$RELEASE_VERSION-bin-$NAME From 12f56334bb308c19d1c6c017fe1ec10808bde12a Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 11 Nov 2014 21:36:48 -0800 Subject: [PATCH 091/652] Support cross building for Scala 2.11 Let's give this another go using a version of Hive that shades its JLine dependency. Author: Prashant Sharma Author: Patrick Wendell Closes #3159 from pwendell/scala-2.11-prashant and squashes the following commits: e93aa3e [Patrick Wendell] Restoring -Phive-thriftserver profile and cleaning up build script. f65d17d [Patrick Wendell] Fixing build issue due to merge conflict a8c41eb [Patrick Wendell] Reverting dev/run-tests back to master state. 7a6eb18 [Patrick Wendell] Merge remote-tracking branch 'apache/master' into scala-2.11-prashant 583aa07 [Prashant Sharma] REVERT ME: removed hive thirftserver 3680e58 [Prashant Sharma] Revert "REVERT ME: Temporarily removing some Cli tests." 935fb47 [Prashant Sharma] Revert "Fixed by disabling a few tests temporarily." 925e90f [Prashant Sharma] Fixed by disabling a few tests temporarily. 2fffed3 [Prashant Sharma] Exclude groovy from sbt build, and also provide a way for such instances in future. 8bd4e40 [Prashant Sharma] Switched to gmaven plus, it fixes random failures observer with its predecessor gmaven. 5272ce5 [Prashant Sharma] SPARK_SCALA_VERSION related bugs. 2121071 [Patrick Wendell] Migrating version detection to PySpark b1ed44d [Patrick Wendell] REVERT ME: Temporarily removing some Cli tests. 1743a73 [Patrick Wendell] Removing decimal test that doesn't work with Scala 2.11 f5cad4e [Patrick Wendell] Add Scala 2.11 docs 210d7e1 [Patrick Wendell] Revert "Testing new Hive version with shaded jline" 48518ce [Patrick Wendell] Remove association of Hive and Thriftserver profiles. e9d0a06 [Patrick Wendell] Revert "Enable thritfserver for Scala 2.10 only" 67ec364 [Patrick Wendell] Guard building of thriftserver around Scala 2.10 check 8502c23 [Patrick Wendell] Enable thritfserver for Scala 2.10 only e22b104 [Patrick Wendell] Small fix in pom file ec402ab [Patrick Wendell] Various fixes 0be5a9d [Patrick Wendell] Testing new Hive version with shaded jline 4eaec65 [Prashant Sharma] Changed scripts to ignore target. 5167bea [Prashant Sharma] small correction a4fcac6 [Prashant Sharma] Run against scala 2.11 on jenkins. 80285f4 [Prashant Sharma] MAven equivalent of setting spark.executor.extraClasspath during tests. 034b369 [Prashant Sharma] Setting test jars on executor classpath during tests from sbt. d4874cb [Prashant Sharma] Fixed Python Runner suite. null check should be first case in scala 2.11. 6f50f13 [Prashant Sharma] Fixed build after rebasing with master. 
We should use ${scala.binary.version} instead of just 2.10 e56ca9d [Prashant Sharma] Print an error if build for 2.10 and 2.11 is spotted. 937c0b8 [Prashant Sharma] SCALA_VERSION -> SPARK_SCALA_VERSION cb059b0 [Prashant Sharma] Code review 0476e5e [Prashant Sharma] Scala 2.11 support with repl and all build changes. (cherry picked from commit daaca14c16dc2c1abc98f15ab8c6f7c14761b627) Signed-off-by: Patrick Wendell --- .rat-excludes | 1 + assembly/pom.xml | 13 +- bin/compute-classpath.sh | 46 +- bin/load-spark-env.sh | 20 + bin/pyspark | 6 +- bin/run-example | 8 +- bin/spark-class | 8 +- core/pom.xml | 57 +- .../apache/spark/deploy/PythonRunner.scala | 2 +- .../org/apache/spark/deploy/SparkSubmit.scala | 2 +- dev/change-version-to-2.10.sh | 20 + dev/change-version-to-2.11.sh | 21 + dev/create-release/create-release.sh | 12 +- dev/run-tests | 13 +- dev/scalastyle | 2 +- docs/building-spark.md | 31 +- docs/sql-programming-guide.md | 2 +- examples/pom.xml | 199 ++- .../streaming/JavaKafkaWordCount.java | 0 .../examples/streaming/KafkaWordCount.scala | 0 .../streaming/TwitterAlgebirdCMS.scala | 0 .../streaming/TwitterAlgebirdHLL.scala | 0 external/mqtt/pom.xml | 5 - make-distribution.sh | 2 +- network/shuffle/pom.xml | 4 +- network/yarn/pom.xml | 2 +- pom.xml | 178 ++- project/SparkBuild.scala | 36 +- project/project/SparkPluginBuild.scala | 2 +- repl/pom.xml | 90 +- .../scala/org/apache/spark/repl/Main.scala | 0 .../apache/spark/repl/SparkCommandLine.scala | 0 .../apache/spark/repl/SparkExprTyper.scala | 0 .../org/apache/spark/repl/SparkHelper.scala | 0 .../org/apache/spark/repl/SparkILoop.scala | 0 .../apache/spark/repl/SparkILoopInit.scala | 0 .../org/apache/spark/repl/SparkIMain.scala | 0 .../org/apache/spark/repl/SparkImports.scala | 0 .../spark/repl/SparkJLineCompletion.scala | 0 .../apache/spark/repl/SparkJLineReader.scala | 0 .../spark/repl/SparkMemberHandlers.scala | 0 .../spark/repl/SparkRunnerSettings.scala | 0 .../org/apache/spark/repl/ReplSuite.scala | 0 .../scala/org/apache/spark/repl/Main.scala | 85 ++ .../apache/spark/repl/SparkExprTyper.scala | 86 ++ .../org/apache/spark/repl/SparkILoop.scala | 966 ++++++++++++ .../org/apache/spark/repl/SparkIMain.scala | 1319 +++++++++++++++++ .../org/apache/spark/repl/SparkImports.scala | 201 +++ .../spark/repl/SparkJLineCompletion.scala | 350 +++++ .../spark/repl/SparkMemberHandlers.scala | 221 +++ .../apache/spark/repl/SparkReplReporter.scala | 53 + .../org/apache/spark/repl/ReplSuite.scala | 326 ++++ sql/catalyst/pom.xml | 29 +- .../catalyst/types/decimal/DecimalSuite.scala | 1 - 54 files changed, 4204 insertions(+), 215 deletions(-) create mode 100755 dev/change-version-to-2.10.sh create mode 100755 dev/change-version-to-2.11.sh rename examples/{ => scala-2.10}/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java (100%) rename examples/{ => scala-2.10}/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala (100%) rename examples/{ => scala-2.10}/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala (100%) rename examples/{ => scala-2.10}/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala (100%) rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/Main.scala (100%) rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala (100%) rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala (100%) rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/SparkHelper.scala (100%) 
rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/SparkILoop.scala (100%) rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala (100%) rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/SparkIMain.scala (100%) rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/SparkImports.scala (100%) rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala (100%) rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala (100%) rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala (100%) rename repl/{ => scala-2.10}/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala (100%) rename repl/{ => scala-2.10}/src/test/scala/org/apache/spark/repl/ReplSuite.scala (100%) create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkImports.scala create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala create mode 100644 repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala diff --git a/.rat-excludes b/.rat-excludes index 20e337246438..d8bee1f8e49c 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -44,6 +44,7 @@ SparkImports.scala SparkJLineCompletion.scala SparkJLineReader.scala SparkMemberHandlers.scala +SparkReplReporter.scala sbt sbt-launch-lib.bash plugins.sbt diff --git a/assembly/pom.xml b/assembly/pom.xml index 31a01e4d8e1d..c65192bde64c 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -66,22 +66,22 @@ org.apache.spark - spark-repl_${scala.binary.version} + spark-streaming_${scala.binary.version} ${project.version} org.apache.spark - spark-streaming_${scala.binary.version} + spark-graphx_${scala.binary.version} ${project.version} org.apache.spark - spark-graphx_${scala.binary.version} + spark-sql_${scala.binary.version} ${project.version} org.apache.spark - spark-sql_${scala.binary.version} + spark-repl_${scala.binary.version} ${project.version} @@ -197,6 +197,11 @@ spark-hive_${scala.binary.version} ${project.version} + + + + hive-thriftserver + org.apache.spark spark-hive-thriftserver_${scala.binary.version} diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 905bbaf99b37..298641f2684d 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -20,8 +20,6 @@ # This script computes Spark's classpath and prints it to stdout; it's used by both the "run" # script and the ExecutorRunner in standalone cluster mode. 
-SCALA_VERSION=2.10 - # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" @@ -36,7 +34,7 @@ else CLASSPATH="$CLASSPATH:$FWDIR/conf" fi -ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION" +ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SPARK_SCALA_VERSION" if [ -n "$JAVA_HOME" ]; then JAR_CMD="$JAVA_HOME/bin/jar" @@ -48,19 +46,19 @@ fi if [ -n "$SPARK_PREPEND_CLASSES" ]; then echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ "classes ahead of assembly." >&2 - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes" fi # Use spark-assembly jar from either RELEASE or assembly directory @@ -123,15 +121,15 @@ fi # Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1 if [[ $SPARK_TESTING == 1 ]]; then - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/test-classes" + 
CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes" fi # Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail ! diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 6d4231b20459..356b3d49b2ff 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -36,3 +36,23 @@ if [ -z "$SPARK_ENV_LOADED" ]; then set +a fi fi + +# Setting SPARK_SCALA_VERSION if not already set. + +if [ -z "$SPARK_SCALA_VERSION" ]; then + + ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" + ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" + + if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then + echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 + echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 + exit 1 + fi + + if [ -d "$ASSEMBLY_DIR2" ]; then + export SPARK_SCALA_VERSION="2.11" + else + export SPARK_SCALA_VERSION="2.10" + fi +fi diff --git a/bin/pyspark b/bin/pyspark index 96f30a260a09..1d8c94d43d28 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -25,7 +25,7 @@ export SPARK_HOME="$FWDIR" source "$FWDIR/bin/utils.sh" -SCALA_VERSION=2.10 +source "$FWDIR"/bin/load-spark-env.sh function usage() { echo "Usage: ./bin/pyspark [options]" 1>&2 @@ -40,7 +40,7 @@ fi # Exit if the user hasn't compiled Spark if [ ! -f "$FWDIR/RELEASE" ]; then # Exit if the user hasn't compiled Spark - ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null + ls "$FWDIR"/assembly/target/scala-$SPARK_SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null if [[ $? != 0 ]]; then echo "Failed to find Spark assembly in $FWDIR/assembly/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 @@ -48,8 +48,6 @@ if [ ! -f "$FWDIR/RELEASE" ]; then fi fi -. "$FWDIR"/bin/load-spark-env.sh - # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` # executable, while the worker would still be launched using PYSPARK_PYTHON. # diff --git a/bin/run-example b/bin/run-example index 34dd71c71880..3d932509426f 100755 --- a/bin/run-example +++ b/bin/run-example @@ -17,12 +17,12 @@ # limitations under the License. # -SCALA_VERSION=2.10 - FWDIR="$(cd "`dirname "$0"`"/..; pwd)" export SPARK_HOME="$FWDIR" EXAMPLES_DIR="$FWDIR"/examples +. 
"$FWDIR"/bin/load-spark-env.sh + if [ -n "$1" ]; then EXAMPLE_CLASS="$1" shift @@ -36,8 +36,8 @@ fi if [ -f "$FWDIR/RELEASE" ]; then export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`" -elif [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar ]; then - export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar`" +elif [ -e "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar ]; then + export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar`" fi if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then diff --git a/bin/spark-class b/bin/spark-class index 925367b0dd18..0d58d95c1aee 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -24,8 +24,6 @@ case "`uname`" in CYGWIN*) cygwin=true;; esac -SCALA_VERSION=2.10 - # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" @@ -128,9 +126,9 @@ fi TOOLS_DIR="$FWDIR"/tools SPARK_TOOLS_JAR="" -if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then +if [ -e "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the SBT build - export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar`" + export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar`" fi if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the Maven build @@ -149,7 +147,7 @@ fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then if test -z "$SPARK_TOOLS_JAR"; then - echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2 + echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2 echo "You need to build Spark before running $1." 1>&2 exit 1 fi diff --git a/core/pom.xml b/core/pom.xml index 92e9f1fc4627..03eb231581dc 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -34,6 +34,34 @@ Spark Project Core http://spark.apache.org/ + + com.twitter + chill_${scala.binary.version} + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + + + com.twitter + chill-java + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + org.apache.hadoop hadoop-client @@ -46,12 +74,12 @@ org.apache.spark - spark-network-common_2.10 + spark-network-common_${scala.binary.version} ${project.version} org.apache.spark - spark-network-shuffle_2.10 + spark-network-shuffle_${scala.binary.version} ${project.version} @@ -132,14 +160,6 @@ net.jpountz.lz4 lz4 - - com.twitter - chill_${scala.binary.version} - - - com.twitter - chill-java - org.roaringbitmap RoaringBitmap @@ -316,14 +336,16 @@ org.scalatest scalatest-maven-plugin - - - ${basedir}/.. 
- 1 - ${spark.classpath} - - + + + test + + test + + + + org.apache.maven.plugins @@ -431,4 +453,5 @@ + diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index af94b05ce384..039c8719e286 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -87,8 +87,8 @@ object PythonRunner { // Strip the URI scheme from the path formattedPath = new URI(formattedPath).getScheme match { - case Utils.windowsDrive(d) if windows => formattedPath case null => formattedPath + case Utils.windowsDrive(d) if windows => formattedPath case _ => new URI(formattedPath).getPath } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b43e68e40f79..8a62519bd231 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -340,7 +340,7 @@ object SparkSubmit { e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { println(s"Failed to load main class $childMainClass.") - println("You need to build Spark with -Phive.") + println("You need to build Spark with -Phive and -Phive-thriftserver.") } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh new file mode 100755 index 000000000000..7473c20d28e0 --- /dev/null +++ b/dev/change-version-to-2.10.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +find . -name 'pom.xml' | grep -v target \ + | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.11|\1_2.10|g' {} diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh new file mode 100755 index 000000000000..3957a9f3ba25 --- /dev/null +++ b/dev/change-version-to-2.11.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +find . 
-name 'pom.xml' | grep -v target \ + | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.10|\1_2.11|g' {} diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 50a9a2fa1cb9..db441b3e4979 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -118,13 +118,13 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" & -make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & -make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Pyarn" & -make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Pyarn" & +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & +make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" & +make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" & make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" & -make_binary_release "mapr3" "-Pmapr3 -Phive" & -make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive" & +make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" & +make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" & wait # Copy data diff --git a/dev/run-tests b/dev/run-tests index de607e434445..328a73bd8b26 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -139,9 +139,6 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_BUILD { - # We always build with Hive because the PySpark Spark SQL tests need it. - BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-0.12.0" - # NOTE: echo "q" is needed because sbt on encountering a build file with failure #+ (either resolution or compilation) prompts the user for input either q, r, etc @@ -151,15 +148,17 @@ CURRENT_BLOCK=$BLOCK_BUILD # QUESTION: Why doesn't 'yes "q"' work? # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? # First build with 0.12 to ensure patches do not break the hive 12 build + HIVE_12_BUILD_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver -Phive-0.12.0" echo "[info] Compile with hive 0.12" echo -e "q\n" \ - | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean hive/compile hive-thriftserver/compile \ + | sbt/sbt $HIVE_12_BUILD_ARGS clean hive/compile hive-thriftserver/compile \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" # Then build with default version(0.13.1) because tests are based on this version - echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS -Phive" + echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS"\ + " -Phive -Phive-thriftserver" echo -e "q\n" \ - | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive package assembly/assembly \ + | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver package assembly/assembly \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" } @@ -174,7 +173,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled. # This must be a single argument, as it is. 
if [ -n "$_RUN_SQL_TESTS" ]; then - SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" + SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" fi if [ -n "$_SQL_TESTS_ONLY" ]; then diff --git a/dev/scalastyle b/dev/scalastyle index ed1b6b730af6..c3c6012e74ff 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,7 +17,7 @@ # limitations under the License. # -echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt +echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt # Check style with YARN alpha built too echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt diff --git a/docs/building-spark.md b/docs/building-spark.md index 238ddae15545..20ba7da5d71f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -101,25 +101,34 @@ mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -Dski # Building With Hive and JDBC Support To enable Hive integration for Spark SQL along with its JDBC server and CLI, -add the `-Phive` profile to your existing build options. By default Spark -will build with Hive 0.13.1 bindings. You can also build for Hive 0.12.0 using -the `-Phive-0.12.0` profile. +add the `-Phive` and `Phive-thriftserver` profiles to your existing build options. +By default Spark will build with Hive 0.13.1 bindings. You can also build for +Hive 0.12.0 using the `-Phive-0.12.0` profile. {% highlight bash %} # Apache Hadoop 2.4.X with Hive 13 support -mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package +mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package # Apache Hadoop 2.4.X with Hive 12 support -mvn -Pyarn -Phive-0.12.0 -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package +mvn -Pyarn -Phive -Phive-thriftserver-0.12.0 -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package {% endhighlight %} +# Building for Scala 2.11 +To produce a Spark package compiled with Scala 2.11, use the `-Pscala-2.11` profile: + + mvn -Pyarn -Phadoop-2.4 -Pscala-2.11 -DskipTests clean package + +Scala 2.11 support in Spark is experimental and does not support a few features. +Specifically, Spark's external Kafka library and JDBC component are not yet +supported in Scala 2.11 builds. + # Spark Tests in Maven Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: - mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package - mvn -Pyarn -Phadoop-2.3 -Phive test + mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-thriftserver clean package + mvn -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test The ScalaTest plugin also supports running only a specific test suite as follows: @@ -182,16 +191,16 @@ can be set to control the SBT build. For example: Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. 
The following is an example of a correct (build, test) sequence: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly - sbt/sbt -Pyarn -Phadoop-2.3 -Phive test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver assembly + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test To run only a specific test suite as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite" + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver "test-only org.apache.spark.repl.ReplSuite" To run test suites of a specific sub project as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver core/test # Speeding up Compilation with Zinc diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ffcce2c58887..48e8267ac072 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -728,7 +728,7 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. -In order to use Hive you must first run "`sbt/sbt -Phive assembly/assembly`" (or use `-Phive` for maven). +Hive support is enabled by adding the `-Phive` and `-Phive-thriftserver` flags to Spark's build. This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. diff --git a/examples/pom.xml b/examples/pom.xml index 910eb55308b9..2ec5728154ab 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,48 +34,6 @@ Spark Project Examples http://spark.apache.org/ - - - kinesis-asl - - - org.apache.spark - spark-streaming-kinesis-asl_${scala.binary.version} - ${project.version} - - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - - - - - hbase-hadoop2 - - - hbase.profile - hadoop2 - - - - 0.98.7-hadoop2 - - - - hbase-hadoop1 - - - !hbase.profile - - - - 0.98.7-hadoop1 - - - - - @@ -124,11 +82,6 @@ spark-streaming-twitter_${scala.binary.version} ${project.version} - - org.apache.spark - spark-streaming-kafka_${scala.binary.version} - ${project.version} - org.apache.spark spark-streaming-flume_${scala.binary.version} @@ -136,12 +89,12 @@ org.apache.spark - spark-streaming-zeromq_${scala.binary.version} + spark-streaming-mqtt_${scala.binary.version} ${project.version} org.apache.spark - spark-streaming-mqtt_${scala.binary.version} + spark-streaming-zeromq_${scala.binary.version} ${project.version} @@ -260,11 +213,6 @@ test-jar test - - com.twitter - algebird-core_${scala.binary.version} - 0.1.11 - org.apache.commons commons-math3 @@ -401,4 +349,147 @@ + + + kinesis-asl + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + + + + + hbase-hadoop2 + + + hbase.profile + hadoop2 + + + + 0.98.7-hadoop2 + + + + hbase-hadoop1 + + + !hbase.profile + + + + 0.98.7-hadoop1 + + + + + scala-2.10 + + true + + + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + + + com.twitter + algebird-core_${scala.binary.version} + 0.1.11 + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-sources + generate-sources + + add-source + + + + src/main/scala 
+ scala-2.10/src/main/scala + scala-2.10/src/main/java + + + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + scala-2.10/src/test/scala + scala-2.10/src/test/java + + + + + + + + + + scala-2.11 + + false + + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-sources + generate-sources + + add-source + + + + src/main/scala + scala-2.11/src/main/scala + + + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + scala-2.11/src/test/scala + + + + + + + + + diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java similarity index 100% rename from examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java rename to examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala similarity index 100% rename from examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala rename to examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala similarity index 100% rename from examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala rename to examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala similarity index 100% rename from examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala rename to examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 371f1f1e9d39..362a76e51593 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -52,11 +52,6 @@ mqtt-client 0.4.0 - - ${akka.group} - akka-zeromq_${scala.binary.version} - ${akka.version} - org.scalatest scalatest_${scala.binary.version} diff --git a/make-distribution.sh b/make-distribution.sh index 0bc839e1dbe4..d46edbc50d15 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -59,7 +59,7 @@ while (( "$#" )); do exit_with_usage ;; --with-hive) - echo "Error: '--with-hive' is no longer supported, use Maven option -Phive" + echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" exit_with_usage ;; --skip-java-test) diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 27c8467687f1..a180a5e5f926 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -39,7 +39,7 @@ org.apache.spark - spark-network-common_2.10 + spark-network-common_${scala.binary.version} ${project.version} @@ -58,7 +58,7 @@ org.apache.spark - spark-network-common_2.10 + spark-network-common_${scala.binary.version} ${project.version} test-jar test diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index 6e6f6f3e7929..85960eb85b48 100644 --- a/network/yarn/pom.xml +++ 
b/network/yarn/pom.xml @@ -39,7 +39,7 @@ org.apache.spark - spark-network-shuffle_2.10 + spark-network-shuffle_${scala.binary.version} ${project.version} diff --git a/pom.xml b/pom.xml index 4e0cd6c151d0..7bbde31e572d 100644 --- a/pom.xml +++ b/pom.xml @@ -97,30 +97,26 @@ sql/catalyst sql/core sql/hive - repl assembly external/twitter - external/kafka external/flume external/flume-sink - external/zeromq external/mqtt + external/zeromq examples + repl UTF-8 UTF-8 - + org.spark-project.akka + 2.3.4-spark 1.6 spark - 2.10.4 - 2.10 2.0.1 0.18.1 shaded-protobuf - org.spark-project.akka - 2.3.4-spark 1.7.5 1.2.17 1.0.4 @@ -137,7 +133,7 @@ 1.6.0rc3 1.2.3 8.1.14.v20131031 - 0.3.6 + 0.5.0 3.0.0 1.7.6 @@ -146,9 +142,13 @@ 1.1.0 4.2.6 3.1.1 - + ${project.build.directory}/spark-test-classpath.txt 64m 512m + 2.10.4 + 2.10 + ${scala.version} + org.scala-lang @@ -267,19 +267,66 @@ + + - org.spark-project.spark unused 1.0.0 + + + org.codehaus.groovy + groovy-all + 2.3.7 + provided + + + ${jline.groupid} + jline + ${jline.version} + + + com.twitter + chill_${scala.binary.version} + ${chill.version} + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + + + com.twitter + chill-java + ${chill.version} + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + org.eclipse.jetty jetty-util @@ -395,36 +442,6 @@ protobuf-java ${protobuf.version} - - com.twitter - chill_${scala.binary.version} - ${chill.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - - - - com.twitter - chill-java - ${chill.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - - ${akka.group} akka-actor_${scala.binary.version} @@ -512,11 +529,6 @@ scala-reflect ${scala.version} - - org.scala-lang - jline - ${scala.version} - org.scala-lang scala-library @@ -965,6 +977,7 @@ ${session.executionRootDirectory} 1 false + ${test_classpath} @@ -1026,6 +1039,47 @@ + + + org.apache.maven.plugins + maven-dependency-plugin + 2.9 + + + test-compile + + build-classpath + + + test + ${test_classpath_file} + + + + + + + + org.codehaus.gmavenplus + gmavenplus-plugin + 1.2 + + + process-test-classes + + execute + + + + + + + + + org.apache.maven.plugins @@ -1335,7 +1389,7 @@ - hive + hive-thriftserver false @@ -1365,5 +1419,35 @@ 10.10.1.1 + + + scala-2.10 + + true + + + 2.10.4 + 2.10 + ${scala.version} + org.scala-lang + + + external/kafka + + + + + scala-2.11 + + false + + + 2.11.2 + 2.11 + 2.12 + jline + + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 351e57a4b578..492607d558de 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -31,8 +31,8 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, - sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq) = + sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, + streamingMqtt, streamingTwitter, streamingZeromq) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", @@ -68,8 +68,8 @@ object SparkBuild extends PomBuild { profiles ++= Seq("spark-ganglia-lgpl") } if (Properties.envOrNone("SPARK_HIVE").isDefined) { - println("NOTE: SPARK_HIVE is deprecated, please use -Phive 
flag.") - profiles ++= Seq("hive") + println("NOTE: SPARK_HIVE is deprecated, please use -Phive and -Phive-thriftserver flags.") + profiles ++= Seq("hive", "hive-thriftserver") } Properties.envOrNone("SPARK_HADOOP_VERSION") match { case Some(v) => @@ -91,13 +91,21 @@ object SparkBuild extends PomBuild { profiles } - override val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match { + override val profiles = { + val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match { case None => backwardCompatibility case Some(v) => if (backwardCompatibility.nonEmpty) println("Note: We ignore environment variables, when use of profile is detected in " + "conjunction with environment variable.") v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq + } + if (profiles.exists(_.contains("scala-"))) { + profiles + } else { + println("Enabled default scala profile") + profiles ++ Seq("scala-2.10") + } } Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { @@ -136,7 +144,8 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ - (allProjects ++ optionallyEnabledProjects ++ assemblyProjects).foreach(enable(sharedSettings)) + (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) + .foreach(enable(sharedSettings ++ ExludedDependencies.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -178,6 +187,16 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +/** + This excludes library dependencies in sbt, which are specified in maven but are + not needed by sbt build. + */ +object ExludedDependencies { + lazy val settings = Seq( + libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") } + ) +} + /** * Following project only exists to pull previous artifacts of Spark for generating * Mima ignores. For more information see: SPARK 2071 @@ -353,8 +372,11 @@ object TestSettings { .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, + // This places test scope jars on the classpath of executors during tests. + javaOptions in Test += + "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files. + map(_.getAbsolutePath).mkString(":").stripSuffix(":"), javaOptions += "-Xmx3g", - // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 3ef2d5451da0..8863f272da41 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -26,7 +26,7 @@ import sbt.Keys._ object SparkPluginDef extends Build { lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) - lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git") + lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") // There is actually no need to publish this artifact. 
def styleSettings = Defaults.defaultSettings ++ Seq ( diff --git a/repl/pom.xml b/repl/pom.xml index af528c891433..bd688c8c1e75 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -38,6 +38,11 @@ + + ${jline.groupid} + jline + ${jline.version} + org.apache.spark spark-core_${scala.binary.version} @@ -75,11 +80,6 @@ scala-reflect ${scala.version} - - org.scala-lang - jline - ${scala.version} - org.slf4j jul-to-slf4j @@ -124,4 +124,84 @@ + + + scala-2.10 + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-sources + generate-sources + + add-source + + + + src/main/scala + scala-2.10/src/main/scala + + + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + scala-2.10/src/test/scala + + + + + + + + + + scala-2.11 + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-sources + generate-sources + + add-source + + + + src/main/scala + scala-2.11/src/main/scala + + + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + scala-2.11/src/test/scala + + + + + + + + + diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/Main.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala similarity index 100% rename from 
repl/src/main/scala/org/apache/spark/repl/SparkImports.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala similarity index 100% rename from repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala rename to repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala new file mode 100644 index 000000000000..5e93a7199507 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.repl + +import org.apache.spark.util.Utils +import org.apache.spark._ + +import scala.tools.nsc.Settings +import scala.tools.nsc.interpreter.SparkILoop + +object Main extends Logging { + + val conf = new SparkConf() + val tmp = System.getProperty("java.io.tmpdir") + val rootDir = conf.get("spark.repl.classdir", tmp) + val outputDir = Utils.createTempDir(rootDir) + val s = new Settings() + s.processArguments(List("-Yrepl-class-based", + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true) + val classServer = new HttpServer(outputDir, new SecurityManager(conf)) + var sparkContext: SparkContext = _ + var interp = new SparkILoop // this is a public var because tests reset it. + + def main(args: Array[String]) { + if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + // Start the classServer and store its URI in a spark system property + // (which will be passed to executors so that they can connect to it) + classServer.start() + interp.process(s) // Repl starts and goes in loop of R.E.P.L + classServer.stop() + Option(sparkContext).map(_.stop) + } + + + def getAddedJars: Array[String] = { + val envJars = sys.env.get("ADD_JARS") + val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } + val jars = propJars.orElse(envJars).getOrElse("") + Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) + } + + def createSparkContext(): SparkContext = { + val execUri = System.getenv("SPARK_EXECUTOR_URI") + val jars = getAddedJars + val conf = new SparkConf() + .setMaster(getMaster) + .setAppName("Spark shell") + .setJars(jars) + .set("spark.repl.class.uri", classServer.uri) + logInfo("Spark class server started at " + classServer.uri) + if (execUri != null) { + conf.set("spark.executor.uri", execUri) + } + if (System.getenv("SPARK_HOME") != null) { + conf.setSparkHome(System.getenv("SPARK_HOME")) + } + sparkContext = new SparkContext(conf) + logInfo("Created spark context..") + sparkContext + } + + private def getMaster: String = { + val master = { + val envMaster = sys.env.get("MASTER") + val propMaster = sys.props.get("spark.master") + propMaster.orElse(envMaster).getOrElse("local[*]") + } + master + } +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala new file mode 100644 index 000000000000..8e519fa67f64 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -0,0 +1,86 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Paul Phillips + */ + +package scala.tools.nsc +package interpreter + +import scala.tools.nsc.ast.parser.Tokens.EOF + +trait SparkExprTyper { + val repl: SparkIMain + + import repl._ + import global.{ reporter => _, Import => _, _ } + import naming.freshInternalVarName + + def symbolOfLine(code: String): Symbol = { + def asExpr(): Symbol = { + val name = freshInternalVarName() + // Typing it with a lazy val would give us the right type, but runs + // into compiler bugs with things like existentials, so we compile it + // behind a def and strip the NullaryMethodType which wraps the expr. 
+ val line = "def " + name + " = " + code + + interpretSynthetic(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + // drop NullaryMethodType + sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType) + case _ => NoSymbol + } + } + def asDefn(): Symbol = { + val old = repl.definedSymbolList.toSet + + interpretSynthetic(code) match { + case IR.Success => + repl.definedSymbolList filterNot old match { + case Nil => NoSymbol + case sym :: Nil => sym + case syms => NoSymbol.newOverloaded(NoPrefix, syms) + } + case _ => NoSymbol + } + } + def asError(): Symbol = { + interpretSynthetic(code) + NoSymbol + } + beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError() + } + + private var typeOfExpressionDepth = 0 + def typeOfExpression(expr: String, silent: Boolean = true): Type = { + if (typeOfExpressionDepth > 2) { + repldbg("Terminating typeOfExpression recursion for expression: " + expr) + return NoType + } + typeOfExpressionDepth += 1 + // Don't presently have a good way to suppress undesirable success output + // while letting errors through, so it is first trying it silently: if there + // is an error, and errors are desired, then it re-evaluates non-silently + // to induce the error message. + try beSilentDuring(symbolOfLine(expr).tpe) match { + case NoType if !silent => symbolOfLine(expr).tpe // generate error + case tpe => tpe + } + finally typeOfExpressionDepth -= 1 + } + + // This only works for proper types. + def typeOfTypeString(typeString: String): Type = { + def asProperType(): Option[Type] = { + val name = freshInternalVarName() + val line = "def %s: %s = ???" format (name, typeString) + interpretSynthetic(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + Some(sym0.asMethod.returnType) + case _ => None + } + } + beSilentDuring(asProperType()) getOrElse NoType + } +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala new file mode 100644 index 000000000000..a591e9fc4622 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -0,0 +1,966 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Alexander Spoon + */ + +package scala +package tools.nsc +package interpreter + +import scala.language.{ implicitConversions, existentials } +import scala.annotation.tailrec +import Predef.{ println => _, _ } +import interpreter.session._ +import StdReplTags._ +import scala.reflect.api.{Mirror, Universe, TypeCreator} +import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName } +import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } +import scala.reflect.{ClassTag, classTag} +import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader } +import ScalaClassLoader._ +import scala.reflect.io.{ File, Directory } +import scala.tools.util._ +import scala.collection.generic.Clearable +import scala.concurrent.{ ExecutionContext, Await, Future, future } +import ExecutionContext.Implicits._ +import java.io.{ BufferedReader, FileReader } + +/** The Scala interactive shell. It provides a read-eval-print loop + * around the Interpreter class. + * After instantiation, clients should call the main() method. + * + * If no in0 is specified, then input will come from the console, and + * the class will attempt to provide input editing feature such as + * input history. + * + * @author Moez A. 
Abdel-Gawad + * @author Lex Spoon + * @version 1.2 + */ +class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) + extends AnyRef + with LoopCommands +{ + def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) + def this() = this(None, new JPrintWriter(Console.out, true)) +// +// @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp +// @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i + + var in: InteractiveReader = _ // the input stream from which commands come + var settings: Settings = _ + var intp: SparkIMain = _ + + var globalFuture: Future[Boolean] = _ + + protected def asyncMessage(msg: String) { + if (isReplInfo || isReplPower) + echoAndRefresh(msg) + } + + def initializeSpark() { + intp.beQuietDuring { + command( """ + @transient val sc = org.apache.spark.repl.Main.createSparkContext(); + """) + command("import org.apache.spark.SparkContext._") + } + echo("Spark context available as sc.") + } + + /** Print a welcome message */ + def printWelcome() { + import org.apache.spark.SPARK_VERSION + echo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + val welcomeMsg = "Using Scala %s (%s, Java %s)".format( + versionString, javaVmName, javaVersion) + echo(welcomeMsg) + echo("Type in expressions to have them evaluated.") + echo("Type :help for more information.") + } + + override def echoCommandMessage(msg: String) { + intp.reporter printUntruncatedMessage msg + } + + // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) + def history = in.history + + // classpath entries added via :cp + var addedClasspath: String = "" + + /** A reverse list of commands to replay if the user requests a :replay */ + var replayCommandStack: List[String] = Nil + + /** A list of commands to replay if the user requests a :replay */ + def replayCommands = replayCommandStack.reverse + + /** Record a command for replay should the user request a :replay */ + def addReplay(cmd: String) = replayCommandStack ::= cmd + + def savingReplayStack[T](body: => T): T = { + val saved = replayCommandStack + try body + finally replayCommandStack = saved + } + def savingReader[T](body: => T): T = { + val saved = in + try body + finally in = saved + } + + /** Close the interpreter and set the var to null. */ + def closeInterpreter() { + if (intp ne null) { + intp.close() + intp = null + } + } + + class SparkILoopInterpreter extends SparkIMain(settings, out) { + outer => + + override lazy val formatting = new Formatting { + def prompt = SparkILoop.this.prompt + } + override protected def parentClassLoader = + settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader ) + } + + /** Create a new interpreter. */ + def createInterpreter() { + if (addedClasspath != "") + settings.classpath append addedClasspath + + intp = new SparkILoopInterpreter + } + + /** print a friendly help message */ + def helpCommand(line: String): Result = { + if (line == "") helpSummary() + else uniqueCommand(line) match { + case Some(lc) => echo("\n" + lc.help) + case _ => ambiguousError(line) + } + } + private def helpSummary() = { + val usageWidth = commands map (_.usageMsg.length) max + val formatStr = "%-" + usageWidth + "s %s" + + echo("All commands can be abbreviated, e.g. 
:he instead of :help.") + + commands foreach { cmd => + echo(formatStr.format(cmd.usageMsg, cmd.help)) + } + } + private def ambiguousError(cmd: String): Result = { + matchingCommands(cmd) match { + case Nil => echo(cmd + ": no such command. Type :help for help.") + case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") + } + Result(keepRunning = true, None) + } + private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) + private def uniqueCommand(cmd: String): Option[LoopCommand] = { + // this lets us add commands willy-nilly and only requires enough command to disambiguate + matchingCommands(cmd) match { + case List(x) => Some(x) + // exact match OK even if otherwise appears ambiguous + case xs => xs find (_.name == cmd) + } + } + + /** Show the history */ + lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { + override def usage = "[num]" + def defaultLines = 20 + + def apply(line: String): Result = { + if (history eq NoHistory) + return "No history available." + + val xs = words(line) + val current = history.index + val count = try xs.head.toInt catch { case _: Exception => defaultLines } + val lines = history.asStrings takeRight count + val offset = current - lines.size + 1 + + for ((line, index) <- lines.zipWithIndex) + echo("%3d %s".format(index + offset, line)) + } + } + + // When you know you are most likely breaking into the middle + // of a line being typed. This softens the blow. + protected def echoAndRefresh(msg: String) = { + echo("\n" + msg) + in.redrawLine() + } + protected def echo(msg: String) = { + out println msg + out.flush() + } + + /** Search the history */ + def searchHistory(_cmdline: String) { + val cmdline = _cmdline.toLowerCase + val offset = history.index - history.size + 1 + + for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) + echo("%d %s".format(index + offset, line)) + } + + private val currentPrompt = Properties.shellPromptString + + /** Prompt to print when awaiting input */ + def prompt = currentPrompt + + import LoopCommand.{ cmd, nullary } + + /** Standard commands **/ + lazy val standardCommands = List( + cmd("cp", "", "add a jar or directory to the classpath", addClasspath), + cmd("edit", "|", "edit history", editCommand), + cmd("help", "[command]", "print this summary or command-specific help", helpCommand), + historyCommand, + cmd("h?", "", "search the history", searchHistory), + cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), + //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand), + cmd("javap", "", "disassemble a file or class name", javapCommand), + cmd("line", "|", "place line(s) at the end of history", lineCommand), + cmd("load", "", "interpret lines in a file", loadCommand), + cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand), + // nullary("power", "enable power user mode", powerCmd), + nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)), + nullary("replay", "reset execution and replay all previous commands", replay), + nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), + cmd("save", "", "save replayable session to a file", saveCommand), + shCommand, + cmd("settings", "[+|-]", "+enable/-disable flags, set compiler options", changeSettings), + nullary("silent", "disable/enable 
automatic printing of results", verbosity), +// cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), +// cmd("kind", "[-v] ", "display the kind of expression's type", kindCommand), + nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) + ) + + /** Power user commands */ +// lazy val powerCommands: List[LoopCommand] = List( +// cmd("phase", "", "set the implicit phase for power commands", phaseCommand) +// ) + + private def importsCommand(line: String): Result = { + val tokens = words(line) + val handlers = intp.languageWildcardHandlers ++ intp.importHandlers + + handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { + case (handler, idx) => + val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) + val imps = handler.implicitSymbols + val found = tokens filter (handler importsSymbolNamed _) + val typeMsg = if (types.isEmpty) "" else types.size + " types" + val termMsg = if (terms.isEmpty) "" else terms.size + " terms" + val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" + val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") + val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") + + intp.reporter.printMessage("%2d) %-30s %s%s".format( + idx + 1, + handler.importString, + statsMsg, + foundMsg + )) + } + } + + private def findToolsJar() = PathResolver.SupplementalLocations.platformTools + + private def addToolsJarToLoader() = { + val cl = findToolsJar() match { + case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader) + case _ => intp.classLoader + } + if (Javap.isAvailable(cl)) { + repldbg(":javap available.") + cl + } + else { + repldbg(":javap unavailable: no tools.jar at " + jdkHome) + intp.classLoader + } + } +// +// protected def newJavap() = +// JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp)) +// +// private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap()) + + // Still todo: modules. +// private def typeCommand(line0: String): Result = { +// line0.trim match { +// case "" => ":type [-v] " +// case s => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") +// } +// } + +// private def kindCommand(expr: String): Result = { +// expr.trim match { +// case "" => ":kind [-v] " +// case s => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") +// } +// } + + private def warningsCommand(): Result = { + if (intp.lastWarnings.isEmpty) + "Can't find any cached warnings." 
+ else + intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) } + } + + private def changeSettings(args: String): Result = { + def showSettings() = { + for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString) + } + def updateSettings() = { + // put aside +flag options + val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+")) + val tmps = new Settings + val (ok, leftover) = tmps.processArguments(rest, processAll = true) + if (!ok) echo("Bad settings request.") + else if (leftover.nonEmpty) echo("Unprocessed settings.") + else { + // boolean flags set-by-user on tmp copy should be off, not on + val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting]) + val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg)) + // update non-flags + settings.processArguments(nonbools, processAll = true) + // also snag multi-value options for clearing, e.g. -Ylog: and -language: + for { + s <- settings.userSetSettings + if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting] + if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init)) + } s match { + case c: Clearable => c.clear() + case _ => + } + def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = { + for (b <- bs) + settings.lookupSetting(name(b)) match { + case Some(s) => + if (s.isInstanceOf[Settings#BooleanSetting]) setter(s) + else echo(s"Not a boolean flag: $b") + case _ => + echo(s"Not an option: $b") + } + } + update(minuses, identity, _.tryToSetFromPropertyValue("false")) // turn off + update(pluses, "-" + _.drop(1), _.tryToSet(Nil)) // turn on + } + } + if (args.isEmpty) showSettings() else updateSettings() + } + + private def javapCommand(line: String): Result = { +// if (javap == null) +// ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome) +// else if (line == "") +// ":javap [-lcsvp] [path1 path2 ...]" +// else +// javap(words(line)) foreach { res => +// if (res.isError) return "Failed: " + res.value +// else res.show() +// } + } + + private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent" + + private def phaseCommand(name: String): Result = { +// val phased: Phased = power.phased +// import phased.NoPhaseName +// +// if (name == "clear") { +// phased.set(NoPhaseName) +// intp.clearExecutionWrapper() +// "Cleared active phase." +// } +// else if (name == "") phased.get match { +// case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" +// case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) +// } +// else { +// val what = phased.parse(name) +// if (what.isEmpty || !phased.set(what)) +// "'" + name + "' does not appear to represent a valid phase." +// else { +// intp.setExecutionWrapper(pathToPhaseWrapper) +// val activeMessage = +// if (what.toString.length == name.length) "" + what +// else "%s (%s)".format(what, name) +// +// "Active phase is now: " + activeMessage +// } +// } + } + + /** Available commands */ + def commands: List[LoopCommand] = standardCommands ++ ( + // if (isReplPower) + // powerCommands + // else + Nil + ) + + val replayQuestionMessage = + """|That entry seems to have slain the compiler. Shall I replay + |your session? I can re-run each line except the last one. 
+ |[y/n] + """.trim.stripMargin + + private val crashRecovery: PartialFunction[Throwable, Boolean] = { + case ex: Throwable => + val (err, explain) = ( + if (intp.isInitializeComplete) + (intp.global.throwableAsString(ex), "") + else + (ex.getMessage, "The compiler did not initialize.\n") + ) + echo(err) + + ex match { + case _: NoSuchMethodError | _: NoClassDefFoundError => + echo("\nUnrecoverable error.") + throw ex + case _ => + def fn(): Boolean = + try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) + catch { case _: RuntimeException => false } + + if (fn()) replay() + else echo("\nAbandoning crashed session.") + } + true + } + + // return false if repl should exit + def processLine(line: String): Boolean = { + import scala.concurrent.duration._ + Await.ready(globalFuture, 60.seconds) + + (line ne null) && (command(line) match { + case Result(false, _) => false + case Result(_, Some(line)) => addReplay(line) ; true + case _ => true + }) + } + + private def readOneLine() = { + out.flush() + in readLine prompt + } + + /** The main read-eval-print loop for the repl. It calls + * command() for each line of input, and stops when + * command() returns false. + */ + @tailrec final def loop() { + if ( try processLine(readOneLine()) catch crashRecovery ) + loop() + } + + /** interpret all lines from a specified file */ + def interpretAllFrom(file: File) { + savingReader { + savingReplayStack { + file applyReader { reader => + in = SimpleReader(reader, out, interactive = false) + echo("Loading " + file + "...") + loop() + } + } + } + } + + /** create a new interpreter and replay the given commands */ + def replay() { + reset() + if (replayCommandStack.isEmpty) + echo("Nothing to replay.") + else for (cmd <- replayCommands) { + echo("Replaying: " + cmd) // flush because maybe cmd will have its own output + command(cmd) + echo("") + } + } + def resetCommand() { + echo("Resetting interpreter state.") + if (replayCommandStack.nonEmpty) { + echo("Forgetting this session history:\n") + replayCommands foreach echo + echo("") + replayCommandStack = Nil + } + if (intp.namedDefinedTerms.nonEmpty) + echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", ")) + if (intp.definedTypes.nonEmpty) + echo("Forgetting defined types: " + intp.definedTypes.mkString(", ")) + + reset() + } + def reset() { + intp.reset() + unleashAndSetPhase() + } + + def lineCommand(what: String): Result = editCommand(what, None) + + // :edit id or :edit line + def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR")) + + def editCommand(what: String, editor: Option[String]): Result = { + def diagnose(code: String) = { + echo("The edited code is incomplete!\n") + val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") + if (errless) echo("The compiler reports no errors.") + } + def historicize(text: String) = history match { + case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true + case _ => false + } + def edit(text: String): Result = editor match { + case Some(ed) => + val tmp = File.makeTemp() + tmp.writeAll(text) + try { + val pr = new ProcessResult(s"$ed ${tmp.path}") + pr.exitCode match { + case 0 => + tmp.safeSlurp() match { + case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.") + case Some(edited) => + echo(edited.lines map ("+" + _) mkString "\n") + val res = intp interpret edited + if (res == IR.Incomplete) diagnose(edited) + else { 
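+                  // interpretation finished (successfully, or with an error already
+                  // reported above), so append the edited text to history and record
+                  // it for :replay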
+ historicize(edited) + Result(lineToRecord = Some(edited), keepRunning = true) + } + case None => echo("Can't read edited text. Did you delete it?") + } + case x => echo(s"Error exit from $ed ($x), ignoring") + } + } finally { + tmp.delete() + } + case None => + if (historicize(text)) echo("Placing text in recent history.") + else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text") + } + + // if what is a number, use it as a line number or range in history + def isNum = what forall (c => c.isDigit || c == '-' || c == '+') + // except that "-" means last value + def isLast = (what == "-") + if (isLast || !isNum) { + val name = if (isLast) intp.mostRecentVar else what + val sym = intp.symbolOfIdent(name) + intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match { + case Some(req) => edit(req.line) + case None => echo(s"No symbol in scope: $what") + } + } else try { + val s = what + // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur) + val (start, len) = + if ((s indexOf '+') > 0) { + val (a,b) = s splitAt (s indexOf '+') + (a.toInt, b.drop(1).toInt) + } else { + (s indexOf '-') match { + case -1 => (s.toInt, 1) + case 0 => val n = s.drop(1).toInt ; (history.index - n, n) + case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n) + case i => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n) + } + } + import scala.collection.JavaConverters._ + val index = (start - 1) max 0 + val text = history match { + case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n" + case _ => history.asStrings.slice(index, index + len) mkString "\n" + } + edit(text) + } catch { + case _: NumberFormatException => echo(s"Bad range '$what'") + echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)") + } + } + + /** fork a shell and run a command */ + lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { + override def usage = "" + def apply(line: String): Result = line match { + case "" => showUsage() + case _ => + val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})" + intp interpret toRun + () + } + } + + def withFile[A](filename: String)(action: File => A): Option[A] = { + val res = Some(File(filename)) filter (_.exists) map action + if (res.isEmpty) echo("That file does not exist") // courtesy side-effect + res + } + + def loadCommand(arg: String) = { + var shouldReplay: Option[String] = None + withFile(arg)(f => { + interpretAllFrom(f) + shouldReplay = Some(":load " + arg) + }) + Result(keepRunning = true, shouldReplay) + } + + def saveCommand(filename: String): Result = ( + if (filename.isEmpty) echo("File name is required.") + else if (replayCommandStack.isEmpty) echo("No replay commands in session") + else File(filename).printlnAll(replayCommands: _*) + ) + + def addClasspath(arg: String): Unit = { + val f = File(arg).normalize + if (f.exists) { + addedClasspath = ClassPath.join(addedClasspath, f.path) + val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) + echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath)) + replay() + } + else echo("The path '" + f + "' doesn't seem to exist.") + } + + def powerCmd(): Result = { + if (isReplPower) "Already in power mode." 
+ else enablePowerMode(isDuringInit = false) + } + def enablePowerMode(isDuringInit: Boolean) = { + replProps.power setValue true + unleashAndSetPhase() + // asyncEcho(isDuringInit, power.banner) + } + private def unleashAndSetPhase() { + if (isReplPower) { + // power.unleash() + // Set the phase to "typer" + // intp beSilentDuring phaseCommand("typer") + } + } + + def asyncEcho(async: Boolean, msg: => String) { + if (async) asyncMessage(msg) + else echo(msg) + } + + def verbosity() = { + val old = intp.printResults + intp.printResults = !old + echo("Switched " + (if (old) "off" else "on") + " result printing.") + } + + /** Run one command submitted by the user. Two values are returned: + * (1) whether to keep running, (2) the line to record for replay, + * if any. */ + def command(line: String): Result = { + if (line startsWith ":") { + val cmd = line.tail takeWhile (x => !x.isWhitespace) + uniqueCommand(cmd) match { + case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) + case _ => ambiguousError(cmd) + } + } + else if (intp.global == null) Result(keepRunning = false, None) // Notice failure to create compiler + else Result(keepRunning = true, interpretStartingWith(line)) + } + + private def readWhile(cond: String => Boolean) = { + Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) + } + + def pasteCommand(arg: String): Result = { + var shouldReplay: Option[String] = None + def result = Result(keepRunning = true, shouldReplay) + val (raw, file) = + if (arg.isEmpty) (false, None) + else { + val r = """(-raw)?(\s+)?([^\-]\S*)?""".r + arg match { + case r(flag, sep, name) => + if (flag != null && name != null && sep == null) + echo(s"""I assume you mean "$flag $name"?""") + (flag != null, Option(name)) + case _ => + echo("usage: :paste -raw file") + return result + } + } + val code = file match { + case Some(name) => + withFile(name)(f => { + shouldReplay = Some(s":paste $arg") + val s = f.slurp.trim + if (s.isEmpty) echo(s"File contains no code: $f") + else echo(s"Pasting file $f...") + s + }) getOrElse "" + case None => + echo("// Entering paste mode (ctrl-D to finish)\n") + val text = (readWhile(_ => true) mkString "\n").trim + if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n") + else echo("\n// Exiting paste mode, now interpreting.\n") + text + } + def interpretCode() = { + val res = intp interpret code + // if input is incomplete, let the compiler try to say why + if (res == IR.Incomplete) { + echo("The pasted code is incomplete!\n") + // Remembrance of Things Pasted in an object + val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") + if (errless) echo("...but compilation found no error? Good luck with that.") + } + } + def compileCode() = { + val errless = intp compileSources new BatchSourceFile("", code) + if (!errless) echo("There were compilation errors!") + } + if (code.nonEmpty) { + if (raw) compileCode() else interpretCode() + } + result + } + + private object paste extends Pasted { + val ContinueString = " | " + val PromptString = "scala> " + + def interpret(line: String): Unit = { + echo(line.trim) + intp interpret line + echo("") + } + + def transcript(start: String) = { + echo("\n// Detected repl transcript paste: ctrl-D to finish.\n") + apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) + } + } + import paste.{ ContinueString, PromptString } + + /** Interpret expressions starting with the first line. 
+ * Read lines until a complete compilation unit is available + * or until a syntax error has been seen. If a full unit is + * read, go ahead and interpret it. Return the full string + * to be recorded for replay, if any. + */ + def interpretStartingWith(code: String): Option[String] = { + // signal completion non-completion input has been received + in.completion.resetVerbosity() + + def reallyInterpret = { + val reallyResult = intp.interpret(code) + (reallyResult, reallyResult match { + case IR.Error => None + case IR.Success => Some(code) + case IR.Incomplete => + if (in.interactive && code.endsWith("\n\n")) { + echo("You typed two blank lines. Starting a new command.") + None + } + else in.readLine(ContinueString) match { + case null => + // we know compilation is going to fail since we're at EOF and the + // parser thinks the input is still incomplete, but since this is + // a file being read non-interactively we want to fail. So we send + // it straight to the compiler for the nice error message. + intp.compileString(code) + None + + case line => interpretStartingWith(code + "\n" + line) + } + }) + } + + /** Here we place ourselves between the user and the interpreter and examine + * the input they are ostensibly submitting. We intervene in several cases: + * + * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. + * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation + * on the previous result. + * 3) If the Completion object's execute returns Some(_), we inject that value + * and avoid the interpreter, as it's likely not valid scala code. + */ + if (code == "") None + else if (!paste.running && code.trim.startsWith(PromptString)) { + paste.transcript(code) + None + } + else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { + interpretStartingWith(intp.mostRecentVar + code) + } + else if (code.trim startsWith "//") { + // line comment, do nothing + None + } + else + reallyInterpret._2 + } + + // runs :load `file` on any files passed via -i + def loadFiles(settings: Settings) = settings match { + case settings: GenericRunnerSettings => + for (filename <- settings.loadfiles.value) { + val cmd = ":load " + filename + command(cmd) + addReplay(cmd) + echo("") + } + case _ => + } + + /** Tries to create a JLineReader, falling back to SimpleReader: + * unless settings or properties are such that it should start + * with SimpleReader. + */ + def chooseReader(settings: Settings): InteractiveReader = { + if (settings.Xnojline || Properties.isEmacsShell) + SimpleReader() + else try new JLineReader( + if (settings.noCompletion) NoCompletion + else new SparkJLineCompletion(intp) + ) + catch { + case ex @ (_: Exception | _: NoClassDefFoundError) => + echo("Failed to created JLineReader: " + ex + "\nFalling back to SimpleReader.") + SimpleReader() + } + } + protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = + u.TypeTag[T]( + m, + new TypeCreator { + def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type = + m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] + }) + + private def loopPostInit() { + // Bind intp somewhere out of the regular namespace where + // we can get at it in generated code. + intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain])) + // Auto-run code via some setting. 
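+    // For example -- assuming the stock `scala.repl.autoruncode` system property is what
+    // `replProps.replAutorunCode` reads -- launching with
+    //   -Dscala.repl.autoruncode=/path/to/init.scala
+    // slurps that file and quietly interprets its contents during startup.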
+ ( replProps.replAutorunCode.option + flatMap (f => io.File(f).safeSlurp()) + foreach (intp quietRun _) + ) + // classloader and power mode setup + intp.setContextClassLoader() + if (isReplPower) { + // replProps.power setValue true + // unleashAndSetPhase() + // asyncMessage(power.banner) + } + // SI-7418 Now, and only now, can we enable TAB completion. + in match { + case x: JLineReader => x.consoleReader.postInit + case _ => + } + } + def process(settings: Settings): Boolean = savingContextLoader { + this.settings = settings + createInterpreter() + + // sets in to some kind of reader depending on environmental cues + in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true)) + globalFuture = future { + intp.initializeSynchronous() + loopPostInit() + !intp.reporter.hasErrors + } + import scala.concurrent.duration._ + Await.ready(globalFuture, 10 seconds) + printWelcome() + initializeSpark() + loadFiles(settings) + + try loop() + catch AbstractOrMissingHandler() + finally closeInterpreter() + + true + } + + @deprecated("Use `process` instead", "2.9.0") + def main(settings: Settings): Unit = process(settings) //used by sbt +} + +object SparkILoop { + implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp + + // Designed primarily for use by test code: take a String with a + // bunch of code, and prints out a transcript of what it would look + // like if you'd just typed it into the repl. + def runForTranscript(code: String, settings: Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) { + override def write(str: String) = { + // completely skip continuation lines + if (str forall (ch => ch.isWhitespace || ch == '|')) () + else super.write(str) + } + } + val input = new BufferedReader(new StringReader(code.trim + "\n")) { + override def readLine(): String = { + val s = super.readLine() + // helping out by printing the line being interpreted. + if (s != null) + output.println(s) + s + } + } + val repl = new SparkILoop(input, output) + if (settings.classpath.isDefault) + settings.classpath.value = sys.props("java.class.path") + + repl process settings + } + } + } + + /** Creates an interpreter loop with default settings and feeds + * the given code to it as input. 
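+   *
+   *  For instance (illustrative only), `SparkILoop.run("val n = 1 + 1\n")` returns the
+   *  console output produced while interpreting that input, captured as a String.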
+ */ + def run(code: String, sets: Settings = new Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val input = new BufferedReader(new StringReader(code)) + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) + val repl = new SparkILoop(input, output) + + if (sets.classpath.isDefault) + sets.classpath.value = sys.props("java.class.path") + + repl process sets + } + } + } + def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala new file mode 100644 index 000000000000..1bb62c84abdd --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -0,0 +1,1319 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Martin Odersky + */ + +package scala +package tools.nsc +package interpreter + +import PartialFunction.cond +import scala.language.implicitConversions +import scala.beans.BeanProperty +import scala.collection.mutable +import scala.concurrent.{ Future, ExecutionContext } +import scala.reflect.runtime.{ universe => ru } +import scala.reflect.{ ClassTag, classTag } +import scala.reflect.internal.util.{ BatchSourceFile, SourceFile } +import scala.tools.util.PathResolver +import scala.tools.nsc.io.AbstractFile +import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings } +import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps } +import scala.tools.nsc.util.Exceptional.unwrap +import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable} + +/** An interpreter for Scala code. + * + * The main public entry points are compile(), interpret(), and bind(). + * The compile() method loads a complete Scala file. The interpret() method + * executes one line of Scala code at the request of the user. The bind() + * method binds an object to a variable that can then be used by later + * interpreted code. + * + * The overall approach is based on compiling the requested code and then + * using a Java classloader and Java reflection to run the code + * and access its results. + * + * In more detail, a single compiler instance is used + * to accumulate all successfully compiled or interpreted Scala code. To + * "interpret" a line of code, the compiler generates a fresh object that + * includes the line of code and which has public member(s) to export + * all variables defined by that code. To extract the result of an + * interpreted line to show the user, a second "result object" is created + * which imports the variables exported by the above object and then + * exports members called "$eval" and "$print". To accomodate user expressions + * that read from variables or methods defined in previous statements, "import" + * statements are used. + * + * This interpreter shares the strengths and weaknesses of using the + * full compiler-to-Java. The main strength is that interpreted code + * behaves exactly as does compiled code, including running at full speed. + * The main weakness is that redefining classes and methods is not handled + * properly, because rebinding at the Java level is technically difficult. + * + * @author Moez A. 
Abdel-Gawad + * @author Lex Spoon + */ +class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings, + protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports { + imain => + + setBindings(createBindings, ScriptContext.ENGINE_SCOPE) + object replOutput extends ReplOutput(settings.Yreploutdir) { } + + @deprecated("Use replOutput.dir instead", "2.11.0") + def virtualDirectory = replOutput.dir + // Used in a test case. + def showDirectory() = replOutput.show(out) + + private[nsc] var printResults = true // whether to print result lines + private[nsc] var totalSilence = false // whether to print anything + private var _initializeComplete = false // compiler is initialized + private var _isInitialized: Future[Boolean] = null // set up initialization future + private var bindExceptions = true // whether to bind the lastException variable + private var _executionWrapper = "" // code to be wrapped around all lines + + /** We're going to go to some trouble to initialize the compiler asynchronously. + * It's critical that nothing call into it until it's been initialized or we will + * run into unrecoverable issues, but the perceived repl startup time goes + * through the roof if we wait for it. So we initialize it with a future and + * use a lazy val to ensure that any attempt to use the compiler object waits + * on the future. + */ + private var _classLoader: util.AbstractFileClassLoader = null // active classloader + private val _compiler: ReplGlobal = newCompiler(settings, reporter) // our private compiler + + def compilerClasspath: Seq[java.net.URL] = ( + if (isInitializeComplete) global.classPath.asURLs + else new PathResolver(settings).result.asURLs // the compiler's classpath + ) + def settings = initialSettings + // Run the code body with the given boolean settings flipped to true. + def withoutWarnings[T](body: => T): T = beQuietDuring { + val saved = settings.nowarn.value + if (!saved) + settings.nowarn.value = true + + try body + finally if (!saved) settings.nowarn.value = false + } + + /** construct an interpreter that reports to Console */ + def this(settings: Settings, out: JPrintWriter) = this(null, settings, out) + def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true)) + def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) + def this(factory: ScriptEngineFactory) = this(factory, new Settings()) + def this() = this(new Settings()) + + lazy val formatting: Formatting = new Formatting { + val prompt = Properties.shellPromptString + } + lazy val reporter: SparkReplReporter = new SparkReplReporter(this) + + import formatting._ + import reporter.{ printMessage, printUntruncatedMessage } + + // This exists mostly because using the reporter too early leads to deadlock. 
+ private def echo(msg: String) { Console println msg } + private def _initSources = List(new BatchSourceFile("", "class $repl_$init { }")) + private def _initialize() = { + try { + // if this crashes, REPL will hang its head in shame + val run = new _compiler.Run() + assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") + run compileSources _initSources + _initializeComplete = true + true + } + catch AbstractOrMissingHandler() + } + private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" + private val logScope = scala.sys.props contains "scala.repl.scope" + private def scopelog(msg: String) = if (logScope) Console.err.println(msg) + + // argument is a thunk to execute after init is done + def initialize(postInitSignal: => Unit) { + synchronized { + if (_isInitialized == null) { + _isInitialized = + Future(try _initialize() finally postInitSignal)(ExecutionContext.global) + } + } + } + def initializeSynchronous(): Unit = { + if (!isInitializeComplete) { + _initialize() + assert(global != null, global) + } + } + def isInitializeComplete = _initializeComplete + + lazy val global: Global = { + if (!isInitializeComplete) _initialize() + _compiler + } + + import global._ + import definitions.{ ObjectClass, termMember, dropNullaryMethod} + + lazy val runtimeMirror = ru.runtimeMirror(classLoader) + + private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol } + + def getClassIfDefined(path: String) = ( + noFatal(runtimeMirror staticClass path) + orElse noFatal(rootMirror staticClass path) + ) + def getModuleIfDefined(path: String) = ( + noFatal(runtimeMirror staticModule path) + orElse noFatal(rootMirror staticModule path) + ) + + implicit class ReplTypeOps(tp: Type) { + def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) + } + + // TODO: If we try to make naming a lazy val, we run into big time + // scalac unhappiness with what look like cycles. It has not been easy to + // reduce, but name resolution clearly takes different paths. + object naming extends { + val global: imain.global.type = imain.global + } with Naming { + // make sure we don't overwrite their unwisely named res3 etc. 
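+    // (e.g. if the user has already bound a `res3` of their own, the generated names
+    //  simply skip past it rather than shadowing it)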
+ def freshUserTermName(): TermName = { + val name = newTermName(freshUserVarName()) + if (replScope containsName name) freshUserTermName() + else name + } + def isInternalTermName(name: Name) = isInternalVarName("" + name) + } + import naming._ + + object deconstruct extends { + val global: imain.global.type = imain.global + } with StructuredTypeStrings + + lazy val memberHandlers = new { + val intp: imain.type = imain + } with SparkMemberHandlers + import memberHandlers._ + + /** Temporarily be quiet */ + def beQuietDuring[T](body: => T): T = { + val saved = printResults + printResults = false + try body + finally printResults = saved + } + def beSilentDuring[T](operation: => T): T = { + val saved = totalSilence + totalSilence = true + try operation + finally totalSilence = saved + } + + def quietRun[T](code: String) = beQuietDuring(interpret(code)) + + /** takes AnyRef because it may be binding a Throwable or an Exceptional */ + private def withLastExceptionLock[T](body: => T, alt: => T): T = { + assert(bindExceptions, "withLastExceptionLock called incorrectly.") + bindExceptions = false + + try beQuietDuring(body) + catch logAndDiscard("withLastExceptionLock", alt) + finally bindExceptions = true + } + + def executionWrapper = _executionWrapper + def setExecutionWrapper(code: String) = _executionWrapper = code + def clearExecutionWrapper() = _executionWrapper = "" + + /** interpreter settings */ + lazy val isettings = new SparkISettings(this) + + /** Instantiate a compiler. Overridable. */ + protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = { + settings.outputDirs setSingleOutput replOutput.dir + settings.exposeEmptyPackage.value = true + new Global(settings, reporter) with ReplGlobal { override def toString: String = "" } + } + + /** Parent classloader. Overridable. */ + protected def parentClassLoader: ClassLoader = + settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() ) + + /* A single class loader is used for all commands interpreted by this Interpreter. + It would also be possible to create a new class loader for each command + to interpret. The advantages of the current approach are: + + - Expressions are only evaluated one time. This is especially + significant for I/O, e.g. "val x = Console.readLine" + + The main disadvantage is: + + - Objects, classes, and methods cannot be rebound. Instead, definitions + shadow the old ones, and old code objects refer to the old + definitions. + */ + def resetClassLoader() = { + repldbg("Setting new classloader: was " + _classLoader) + _classLoader = null + ensureClassLoader() + } + final def ensureClassLoader() { + if (_classLoader == null) + _classLoader = makeClassLoader() + } + def classLoader: util.AbstractFileClassLoader = { + ensureClassLoader() + _classLoader + } + + def backticked(s: String): String = ( + (s split '.').toList map { + case "_" => "_" + case s if nme.keywords(newTermName(s)) => s"`$s`" + case s => s + } mkString "." 
+ ) + def readRootPath(readPath: String) = getModuleIfDefined(readPath) + + abstract class PhaseDependentOps { + def shift[T](op: => T): T + + def path(name: => Name): String = shift(path(symbolOfName(name))) + def path(sym: Symbol): String = backticked(shift(sym.fullName)) + def sig(sym: Symbol): String = shift(sym.defString) + } + object typerOp extends PhaseDependentOps { + def shift[T](op: => T): T = exitingTyper(op) + } + object flatOp extends PhaseDependentOps { + def shift[T](op: => T): T = exitingFlatten(op) + } + + def originalPath(name: String): String = originalPath(name: TermName) + def originalPath(name: Name): String = typerOp path name + def originalPath(sym: Symbol): String = typerOp path sym + def flatPath(sym: Symbol): String = flatOp shift sym.javaClassName + def translatePath(path: String) = { + val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path) + sym.toOption map flatPath + } + def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath + + private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) { + /** Overridden here to try translating a simple name to the generated + * class name if the original attempt fails. This method is used by + * getResourceAsStream as well as findClass. + */ + override protected def findAbstractFile(name: String): AbstractFile = + super.findAbstractFile(name) match { + case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull + case file => file + } + } + private def makeClassLoader(): util.AbstractFileClassLoader = + new TranslatingClassLoader(parentClassLoader match { + case null => ScalaClassLoader fromURLs compilerClasspath + case p => new ScalaClassLoader.URLClassLoader(compilerClasspath, p) + }) + + // Set the current Java "context" class loader to this interpreter's class loader + def setContextClassLoader() = classLoader.setAsContext() + + def allDefinedNames: List[Name] = exitingTyper(replScope.toList.map(_.name).sorted) + def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted + + /** Most recent tree handled which wasn't wholly synthetic. */ + private def mostRecentlyHandledTree: Option[Tree] = { + prevRequests.reverse foreach { req => + req.handlers.reverse foreach { + case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member) + case _ => () + } + } + None + } + + private def updateReplScope(sym: Symbol, isDefined: Boolean) { + def log(what: String) { + val mark = if (sym.isType) "t " else "v " + val name = exitingTyper(sym.nameString) + val info = cleanTypeAfterTyper(sym) + val defn = sym defStringSeenAs info + + scopelog(f"[$mark$what%6s] $name%-25s $defn%s") + } + if (ObjectClass isSubClass sym.owner) return + // unlink previous + replScope lookupAll sym.name foreach { sym => + log("unlink") + replScope unlink sym + } + val what = if (isDefined) "define" else "import" + log(what) + replScope enter sym + } + + def recordRequest(req: Request) { + if (req == null) + return + + prevRequests += req + + // warning about serially defining companions. It'd be easy + // enough to just redefine them together but that may not always + // be what people want so I'm waiting until I can do it better. 
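+    // (e.g. `class C` entered on one line and `object C` entered later end up in separate
+    //  wrapper objects, so they are not real companions; pasting both definitions in one
+    //  :paste block avoids the warning below)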
+ exitingTyper { + req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym => + val oldSym = replScope lookup newSym.name.companionName + if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) { + replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.") + replwarn("Companions must be defined together; you may wish to use :paste mode for this.") + } + } + } + exitingTyper { + req.imports foreach (sym => updateReplScope(sym, isDefined = false)) + req.defines foreach (sym => updateReplScope(sym, isDefined = true)) + } + } + + private[nsc] def replwarn(msg: => String) { + if (!settings.nowarnings) + printMessage(msg) + } + + def compileSourcesKeepingRun(sources: SourceFile*) = { + val run = new Run() + assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") + reporter.reset() + run compileSources sources.toList + (!reporter.hasErrors, run) + } + + /** Compile an nsc SourceFile. Returns true if there are + * no compilation errors, or false otherwise. + */ + def compileSources(sources: SourceFile*): Boolean = + compileSourcesKeepingRun(sources: _*)._1 + + /** Compile a string. Returns true if there are no + * compilation errors, or false otherwise. + */ + def compileString(code: String): Boolean = + compileSources(new BatchSourceFile("