diff --git a/.rat-excludes b/.rat-excludes index 769defbac11b7..e3c033145b83a 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -18,6 +18,7 @@ fairscheduler.xml.template spark-defaults.conf.template log4j.properties log4j.properties.template +metrics.properties metrics.properties.template slaves slaves.template diff --git a/CHANGES.txt b/CHANGES.txt new file mode 100644 index 0000000000000..7da0244899e5a --- /dev/null +++ b/CHANGES.txt @@ -0,0 +1,7338 @@ +Spark Change Log +---------------- + +Release 1.3.1 + + [SQL] Use path.makeQualified in newParquet. + Yin Huai + 2015-04-04 23:26:10 +0800 + Commit: eb57d4f, github.com/apache/spark/pull/5353 + + [SPARK-6700] disable flaky test + Davies Liu + 2015-04-03 15:22:21 -0700 + Commit: 3366af6, github.com/apache/spark/pull/5356 + + [SPARK-6688] [core] Always use resolved URIs in EventLoggingListener. + Marcelo Vanzin + 2015-04-03 11:54:31 -0700 + Commit: f17a2fe, github.com/apache/spark/pull/5340 + + [SPARK-6575][SQL] Converted Parquet Metastore tables no longer cache metadata + Yin Huai + 2015-04-03 14:40:36 +0800 + Commit: 0c1b78b, github.com/apache/spark/pull/5339 + + [SPARK-6621][Core] Fix the bug that calling EventLoop.stop in EventLoop.onReceive/onError/onStart doesn't call onStop + zsxwing + 2015-04-02 22:54:30 -0700 + Commit: ac705aa, github.com/apache/spark/pull/5280 + + [SPARK-6345][STREAMING][MLLIB] Fix for training with prediction + freeman + 2015-04-02 21:37:44 -0700 + Commit: d21f779, github.com/apache/spark/pull/5037 + + [CORE] The descriptionof jobHistory config should be spark.history.fs.logDirectory + KaiXinXiaoLei + 2015-04-02 20:24:31 -0700 + Commit: 17ab6b0, github.com/apache/spark/pull/5332 + + [SPARK-6575][SQL] Converted Parquet Metastore tables no longer cache metadata + Yin Huai + 2015-04-02 20:23:08 -0700 + Commit: 0c1c0fb, github.com/apache/spark/pull/5339 + + [SPARK-6650] [core] Stop ExecutorAllocationManager when context stops. + Marcelo Vanzin + 2015-04-02 19:48:55 -0700 + Commit: 0ef46b2, github.com/apache/spark/pull/5311 + + [SPARK-6686][SQL] Use resolved output instead of names for toDF rename + Michael Armbrust + 2015-04-02 18:30:55 -0700 + Commit: 2927af1, github.com/apache/spark/pull/5337 + + [SPARK-6672][SQL] convert row to catalyst in createDataFrame(RDD[Row], ...) + Xiangrui Meng + 2015-04-02 17:57:01 +0800 + Commit: c2694bb, github.com/apache/spark/pull/5329 + + [SPARK-6618][SPARK-6669][SQL] Lock Hive metastore client correctly. + Yin Huai , Michael Armbrust + 2015-04-02 16:46:50 -0700 + Commit: e6ee95c, github.com/apache/spark/pull/5333 + + [Minor] [SQL] Follow-up of PR #5210 + Cheng Lian + 2015-04-02 16:15:34 -0700 + Commit: 4f1fe3f, github.com/apache/spark/pull/5219 + + [SPARK-6655][SQL] We need to read the schema of a data source table stored in spark.sql.sources.schema property + Yin Huai + 2015-04-02 16:02:31 -0700 + Commit: aecec07, github.com/apache/spark/pull/5313 + + [SQL] Throw UnsupportedOperationException instead of NotImplementedError + Michael Armbrust + 2015-04-02 16:01:03 -0700 + Commit: 78ba245, github.com/apache/spark/pull/5315 + + SPARK-6414: Spark driver failed with NPE on job cancelation + Hung Lin + 2015-04-02 14:01:43 -0700 + Commit: 58e2b3f, github.com/apache/spark/pull/5124 + + [SPARK-6079] Use index to speed up StatusTracker.getJobIdsForGroup() + Josh Rosen + 2015-03-25 17:40:00 -0700 + Commit: a6664dc, github.com/apache/spark/pull/4830 + + [SPARK-6667] [PySpark] remove setReuseAddress + Davies Liu + 2015-04-02 12:18:33 -0700 + Commit: ee2bd70, github.com/apache/spark/pull/5324 + + Revert "[SPARK-6618][SQL] HiveMetastoreCatalog.lookupRelation should use fine-grained lock" + Cheng Lian + 2015-04-02 12:59:38 +0800 + Commit: 1160cc9 + + [SQL] SPARK-6658: Update DataFrame documentation to refer to correct types + Michael Armbrust + 2015-04-01 18:00:07 -0400 + Commit: 223dd3f + + [SPARK-6578] Small rewrite to make the logic more clear in MessageWithHeader.transferTo. + Reynold Xin + 2015-04-01 18:36:06 -0700 + Commit: d697b76, github.com/apache/spark/pull/5319 + + [SPARK-6660][MLLIB] pythonToJava doesn't recognize object arrays + Xiangrui Meng + 2015-04-01 18:17:07 -0700 + Commit: 0d1e476, github.com/apache/spark/pull/5318 + + [SPARK-6553] [pyspark] Support functools.partial as UDF + ksonj + 2015-04-01 17:23:57 -0700 + Commit: 98f72df, github.com/apache/spark/pull/5206 + + [SPARK-6642][MLLIB] use 1.2 lambda scaling and remove addImplicit from NormalEquation + Xiangrui Meng + 2015-04-01 16:47:18 -0700 + Commit: bc04fa2, github.com/apache/spark/pull/5314 + + [SPARK-6578] [core] Fix thread-safety issue in outbound path of network library. + Marcelo Vanzin + 2015-04-01 16:06:11 -0700 + Commit: 1c31ebd, github.com/apache/spark/pull/5234 + + [SPARK-6657] [Python] [Docs] fixed python doc build warnings + Joseph K. Bradley + 2015-04-01 15:15:47 -0700 + Commit: e347a7a, github.com/apache/spark/pull/5317 + + [SPARK-6651][MLLIB] delegate dense vector arithmetics to the underlying numpy array + Xiangrui Meng + 2015-04-01 13:29:04 -0700 + Commit: f50d95a, github.com/apache/spark/pull/5312 + + SPARK-6626 [DOCS]: Corrected Scala:TwitterUtils parameters + jayson + 2015-04-01 11:12:55 +0100 + Commit: 7d029cb, github.com/apache/spark/pull/5295 + + [Doc] Improve Python DataFrame documentation + Reynold Xin + 2015-03-31 18:31:36 -0700 + Commit: e527b35, github.com/apache/spark/pull/5287 + + [SPARK-6614] OutputCommitCoordinator should clear authorized committer only after authorized committer fails, not after any failure + Josh Rosen + 2015-03-31 16:18:39 -0700 + Commit: c4c982a, github.com/apache/spark/pull/5276 + + [SPARK-6633][SQL] Should be "Contains" instead of "EndsWith" when constructing sources.StringContains + Liang-Chi Hsieh + 2015-03-31 13:18:07 -0700 + Commit: d851646, github.com/apache/spark/pull/5299 + + [SPARK-5371][SQL] Propagate types after function conversion, before futher resolution + Michael Armbrust + 2015-03-31 11:34:29 -0700 + Commit: 5a957fe, github.com/apache/spark/pull/5278 + + [SPARK-6145][SQL] fix ORDER BY on nested fields + Michael Armbrust + 2015-03-31 11:23:18 -0700 + Commit: 045228f, github.com/apache/spark/pull/5189 + + [SPARK-6575] [SQL] Adds configuration to disable schema merging while converting metastore Parquet tables + Cheng Lian + 2015-03-31 11:21:15 -0700 + Commit: 778c876, github.com/apache/spark/pull/5231 + + [SPARK-6555] [SQL] Overrides equals() and hashCode() for MetastoreRelation + Cheng Lian + 2015-03-31 11:18:25 -0700 + Commit: 9ebefb1, github.com/apache/spark/pull/5289 + + [SPARK-6618][SQL] HiveMetastoreCatalog.lookupRelation should use fine-grained lock + Yin Huai + 2015-03-31 16:28:40 +0800 + Commit: fd600ce, github.com/apache/spark/pull/5281 + + [SPARK-6623][SQL] Alias DataFrame.na.drop and DataFrame.na.fill in Python. + Reynold Xin + 2015-03-31 00:25:23 -0700 + Commit: cf651a4, github.com/apache/spark/pull/5284 + + [SPARK-6625][SQL] Add common string filters to data sources. + Reynold Xin + 2015-03-31 00:19:51 -0700 + Commit: a97d4e6, github.com/apache/spark/pull/5285 + + [SPARK-6119][SQL] DataFrame support for missing data handling + Reynold Xin + 2015-03-30 20:47:10 -0700 + Commit: 67c885e, github.com/apache/spark/pull/5274 + + [SPARK-6369] [SQL] Uses commit coordinator to help committing Hive and Parquet tables + Cheng Lian + 2015-03-31 07:48:37 +0800 + Commit: fedbfc7, github.com/apache/spark/pull/5139 + + [SPARK-6603] [PySpark] [SQL] add SQLContext.udf and deprecate inferSchema() and applySchema + Davies Liu + 2015-03-30 15:47:00 -0700 + Commit: 30e7c63, github.com/apache/spark/pull/5273 + + [SPARK-6592][SQL] fix filter for scaladoc to generate API doc for Row class under catalyst dir + CodingCat + 2015-03-30 11:54:44 -0700 + Commit: f9d4efa, github.com/apache/spark/pull/5252 + + [SPARK-6571][MLLIB] use wrapper in MatrixFactorizationModel.load + Xiangrui Meng + 2015-03-28 15:08:05 -0700 + Commit: 93a7166, github.com/apache/spark/pull/5243 + + [SPARK-6595][SQL] MetastoreRelation should be a MultiInstanceRelation + Michael Armbrust + 2015-03-30 22:24:12 +0800 + Commit: c411530, github.com/apache/spark/pull/5251 + + [SPARK-6558] Utils.getCurrentUserName returns the full principal name instead of login name + Thomas Graves + 2015-03-29 12:43:30 +0100 + Commit: f8132de, github.com/apache/spark/pull/5229 + + [SPARK-5750][SPARK-3441][SPARK-5836][CORE] Added documentation explaining shuffle + Ilya Ganelin , Ilya Ganelin + 2015-03-30 11:52:02 +0100 + Commit: 1c59a4b, github.com/apache/spark/pull/5074 + + [spark-sql] a better exception message than "scala.MatchError" for unsupported types in Schema creation + Eran Medan + 2015-03-30 00:02:52 -0700 + Commit: 4859c40, github.com/apache/spark/pull/5235 + + [HOTFIX] Build break due to NoRelation cherry-pick. + Reynold Xin + 2015-03-29 12:07:28 -0700 + Commit: 6181366 + + [DOC] Improvements to Python docs. + Reynold Xin + 2015-03-28 23:59:27 -0700 + Commit: 3db0844, github.com/apache/spark/pull/5238 + + [SPARK-6538][SQL] Add missing nullable Metastore fields when merging a Parquet schema + Adam Budde + 2015-03-28 09:14:09 +0800 + Commit: 5e04f45, github.com/apache/spark/pull/5214 + + [SPARK-6564][SQL] SQLContext.emptyDataFrame should contain 0 row, not 1 row + Reynold Xin + 2015-03-27 14:56:57 -0700 + Commit: 7006858, github.com/apache/spark/pull/5226 + + [SPARK-6544][build] Increment Avro version from 1.7.6 to 1.7.7 + Dean Chen + 2015-03-27 14:32:51 +0000 + Commit: fefd49f, github.com/apache/spark/pull/5193 + + [SPARK-6574] [PySpark] fix sql example + Davies Liu + 2015-03-27 11:42:26 -0700 + Commit: b902a95, github.com/apache/spark/pull/5230 + + [SPARK-6550][SQL] Use analyzed plan in DataFrame + Michael Armbrust + 2015-03-27 11:40:00 -0700 + Commit: bc75189, github.com/apache/spark/pull/5217 + + [SPARK-6341][mllib] Upgrade breeze from 0.11.1 to 0.11.2 + Yu ISHIKAWA + 2015-03-27 00:15:02 -0700 + Commit: b318858, github.com/apache/spark/pull/5222 + + [DOCS][SQL] Fix JDBC example + Michael Armbrust + 2015-03-26 14:51:46 -0700 + Commit: 54d92b5, github.com/apache/spark/pull/5192 + + [SPARK-6554] [SQL] Don't push down predicates which reference partition column(s) + Cheng Lian + 2015-03-26 13:11:37 -0700 + Commit: 3d54578, github.com/apache/spark/pull/5210 + + [SPARK-6117] [SQL] Improvements to DataFrame.describe() + Reynold Xin + 2015-03-26 12:26:13 -0700 + Commit: 28e3a1e, github.com/apache/spark/pull/5201 + + [SPARK-6117] [SQL] add describe function to DataFrame for summary statis... + azagrebin + 2015-03-26 00:25:04 -0700 + Commit: 84735c3, github.com/apache/spark/pull/5073 + + SPARK-6480 [CORE] histogram() bucket function is wrong in some simple edge cases + Sean Owen + 2015-03-26 15:00:23 +0000 + Commit: aa2d157, github.com/apache/spark/pull/5148 + + [SPARK-6491] Spark will put the current working dir to the CLASSPATH + guliangliang + 2015-03-26 13:28:56 +0000 + Commit: 5b5f0e2, github.com/apache/spark/pull/5156 + + [SQL][SPARK-6471]: Metastore schema should only be a subset of parquet schema to support dropping of columns using replace columns + Yash Datta + 2015-03-26 21:13:38 +0800 + Commit: 836c921, github.com/apache/spark/pull/5141 + + [SPARK-6465][SQL] Fix serialization of GenericRowWithSchema using kryo + Michael Armbrust + 2015-03-26 18:46:57 +0800 + Commit: 8254996, github.com/apache/spark/pull/5191 + + [SPARK-6536] [PySpark] Column.inSet() in Python + Davies Liu + 2015-03-26 00:01:24 -0700 + Commit: 0ba7599, github.com/apache/spark/pull/5190 + + [SPARK-6463][SQL] AttributeSet.equal should compare size + sisihj , Michael Armbrust + 2015-03-25 19:21:54 -0700 + Commit: 9edb34f, github.com/apache/spark/pull/5194 + + [SPARK-6450] [SQL] Fixes metastore Parquet table conversion + Cheng Lian + 2015-03-25 17:40:19 -0700 + Commit: 0cd4748, github.com/apache/spark/pull/5183 + + [SPARK-6409][SQL] It is not necessary that avoid old inteface of hive, because this will make some UDAF can not work. + DoingDone9 <799203320@qq.com> + 2015-03-25 11:11:52 -0700 + Commit: 4efa6c5, github.com/apache/spark/pull/5131 + + SPARK-6063 MLlib doesn't pass mvn scalastyle check due to UTF chars in LDAModel.scala + Michael Griffiths , Griffiths, Michael (NYC-RPM) + 2015-02-28 14:47:39 +0000 + Commit: 6791f42, github.com/apache/spark/pull/4815 + + [SPARK-6496] [MLLIB] GeneralizedLinearAlgorithm.run(input, initialWeights) should initialize numFeatures + Yanbo Liang + 2015-03-25 17:05:56 +0000 + Commit: 2be4255, github.com/apache/spark/pull/5167 + + [DOCUMENTATION]Fixed Missing Type Import in Documentation + Bill Chambers , anabranch + 2015-03-24 22:24:35 -0700 + Commit: 8e4e2e3, github.com/apache/spark/pull/5179 + + [SPARK-6469] Improving documentation on YARN local directories usage + Christophe Préaud + 2015-03-24 17:05:49 -0700 + Commit: 6af9408, github.com/apache/spark/pull/5165 + + [SPARK-3570] Include time to open files in shuffle write time. + Kay Ousterhout + 2015-03-24 16:29:40 -0700 + Commit: e4db5a3, github.com/apache/spark/pull/4550 + + [SPARK-6088] Correct how tasks that get remote results are shown in UI. + Kay Ousterhout + 2015-03-24 16:26:43 -0700 + Commit: de8b2d4, github.com/apache/spark/pull/4839 + + [SPARK-6428][SQL] Added explicit types for all public methods in catalyst + Reynold Xin + 2015-03-24 16:03:55 -0700 + Commit: 586e0d9, github.com/apache/spark/pull/5162 + + [SPARK-6209] Clean up connections in ExecutorClassLoader after failing to load classes (master branch PR) + Josh Rosen + 2015-03-24 14:38:20 -0700 + Commit: dcf56aa, github.com/apache/spark/pull/4944 + + [SPARK-6458][SQL] Better error messages for invalid data sources + Michael Armbrust + 2015-03-24 14:10:56 -0700 + Commit: f48c16d, github.com/apache/spark/pull/5158 + + [SPARK-6376][SQL] Avoid eliminating subqueries until optimization + Michael Armbrust + 2015-03-24 14:08:20 -0700 + Commit: df671bc, github.com/apache/spark/pull/5160 + + [SPARK-6375][SQL] Fix formatting of error messages. + Michael Armbrust + 2015-03-24 13:22:46 -0700 + Commit: 92bf888, github.com/apache/spark/pull/5155 + + Revert "[SPARK-5680][SQL] Sum function on all null values, should return zero" + Michael Armbrust + 2015-03-24 12:32:25 -0700 + Commit: 930b667 + + [SPARK-6054][SQL] Fix transformations of TreeNodes that hold StructTypes + Michael Armbrust + 2015-03-24 12:28:01 -0700 + Commit: c699e2b, github.com/apache/spark/pull/5157 + + [SPARK-6437][SQL] Use completion iterator to close external sorter + Michael Armbrust + 2015-03-24 12:10:30 -0700 + Commit: c0101d3, github.com/apache/spark/pull/5161 + + [SPARK-6459][SQL] Warn when constructing trivially true equals predicate + Michael Armbrust + 2015-03-24 12:09:02 -0700 + Commit: f0141ca, github.com/apache/spark/pull/5163 + + [SPARK-5955][MLLIB] add checkpointInterval to ALS + Xiangrui Meng + 2015-03-20 15:02:57 -0400 + Commit: bc92a2e, github.com/apache/spark/pull/5076 + + [ML][docs][minor] Define LabeledDocument/Document classes in CV example + Peter Rudenko + 2015-03-24 16:33:38 +0000 + Commit: 4ff5771, github.com/apache/spark/pull/5135 + + [SPARK-5559] [Streaming] [Test] Remove oppotunity we met flakiness when running FlumeStreamSuite + Kousuke Saruta + 2015-03-24 16:13:25 +0000 + Commit: 8722369, github.com/apache/spark/pull/4337 + + Update the command to use IPython notebook + Cong Yue + 2015-03-24 12:56:13 +0000 + Commit: e545143, github.com/apache/spark/pull/5111 + + [SPARK-6452] [SQL] Checks for missing attributes and unresolved operator for all types of operator + Cheng Lian + 2015-03-24 01:12:11 -0700 + Commit: 6f10142, github.com/apache/spark/pull/5129 + + [SPARK-6124] Support jdbc connection properties in OPTIONS part of the query + Volodymyr Lyubinets + 2015-03-23 17:00:27 -0700 + Commit: 04b2078, github.com/apache/spark/pull/4859 + + [SPARK-6397][SQL] Check the missingInput simply + Yadong Qi + 2015-03-23 18:16:49 +0800 + Commit: a29f493, github.com/apache/spark/pull/5132 + + [SPARK-4985] [SQL] parquet support for date type + Daoyuan Wang + 2015-03-23 11:46:16 +0800 + Commit: 60b9b96, github.com/apache/spark/pull/3822 + + [SPARK-6337][Documentation, SQL]Spark 1.3 doc fixes + vinodkc + 2015-03-22 20:00:08 +0000 + Commit: 857e8a6, github.com/apache/spark/pull/5112 + + SPARK-6454 [DOCS] Fix links to pyspark api + Kamil Smuga , stderr + 2015-03-22 15:56:25 +0000 + Commit: 3ba295f, github.com/apache/spark/pull/5120 + + [SPARK-6408] [SQL] Fix JDBCRDD filtering string literals + ypcat , Pei-Lun Lee + 2015-03-22 15:49:13 +0800 + Commit: e60fbf6, github.com/apache/spark/pull/5087 + + [SPARK-6428][SQL] Added explicit type for all public methods for Hive module + Reynold Xin + 2015-03-21 14:30:04 -0700 + Commit: 0021d22, github.com/apache/spark/pull/5108 + + [SPARK-6428][SQL] Added explicit type for all public methods in sql/core + Reynold Xin + 2015-03-20 15:47:07 -0700 + Commit: c964588, github.com/apache/spark/pull/5104 + + [SPARK-6250][SPARK-6146][SPARK-5911][SQL] Types are now reserved words in DDL parser. + Yin Huai + 2015-03-21 13:27:53 -0700 + Commit: 102daaf, github.com/apache/spark/pull/5078 + + [SPARK-5680][SQL] Sum function on all null values, should return zero + Venkata Ramana G , Venkata Ramana Gollamudi + 2015-03-21 13:24:24 -0700 + Commit: 93975a3, github.com/apache/spark/pull/4466 + + [SPARK-5320][SQL]Add statistics method at NoRelation (override super). + x1- + 2015-03-21 13:22:34 -0700 + Commit: cba6842, github.com/apache/spark/pull/5105 + + [SPARK-5821] [SQL] JSON CTAS command should throw error message when delete path failure + Yanbo Liang , Yanbo Liang + 2015-03-21 11:23:28 +0800 + Commit: 8de90c7, github.com/apache/spark/pull/4610 + + [SPARK-6315] [SQL] Also tries the case class string parser while reading Parquet schema + Cheng Lian + 2015-03-21 11:18:45 +0800 + Commit: b75943f, github.com/apache/spark/pull/5034 + + [SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful + Yanbo Liang + 2015-03-21 10:53:04 +0800 + Commit: df83e21, github.com/apache/spark/pull/5107 + + [SPARK-6421][MLLIB] _regression_train_wrapper does not test initialWeights correctly + lewuathe + 2015-03-20 17:18:18 -0400 + Commit: aff9f8d, github.com/apache/spark/pull/5101 + + [SPARK-6286][Mesos][minor] Handle missing Mesos case TASK_ERROR + Jongyoul Lee + 2015-03-20 12:24:34 +0000 + Commit: db812d9, github.com/apache/spark/pull/5088 + + [SPARK-6222][Streaming] Dont delete checkpoint data when doing pre-batch-start checkpoint + Tathagata Das + 2015-03-19 02:15:50 -0400 + Commit: 03e263f, github.com/apache/spark/pull/5008 + + [SPARK-6325] [core,yarn] Do not change target executor count when killing executors. + Marcelo Vanzin + 2015-03-18 09:18:28 -0400 + Commit: 1723f05, github.com/apache/spark/pull/5018 + + [SPARK-6286][minor] Handle missing Mesos case TASK_ERROR. + Iulian Dragos + 2015-03-18 09:15:33 -0400 + Commit: ff0a7f4, github.com/apache/spark/pull/5000 + + [SPARK-6247][SQL] Fix resolution of ambiguous joins caused by new aliases + Michael Armbrust + 2015-03-17 19:47:51 -0700 + Commit: ba8352c, github.com/apache/spark/pull/5062 + + [SPARK-6383][SQL]Fixed compiler and errors in Dataframe examples + Tijo Thomas + 2015-03-17 18:50:19 -0700 + Commit: cee6d08, github.com/apache/spark/pull/5068 + + [SPARK-6366][SQL] In Python API, the default save mode for save and saveAsTable should be "error" instead of "append". + Yin Huai + 2015-03-18 09:41:06 +0800 + Commit: 3ea38bc, github.com/apache/spark/pull/5053 + + [SPARK-6330] [SQL] Add a test case for SPARK-6330 + Pei-Lun Lee + 2015-03-18 08:34:46 +0800 + Commit: 9d88f0c, github.com/apache/spark/pull/5039 + + [SPARK-6336] LBFGS should document what convergenceTol means + lewuathe + 2015-03-17 12:11:57 -0700 + Commit: 476c4e1, github.com/apache/spark/pull/5033 + + [SPARK-6365] jetty-security also needed for SPARK_PREPEND_CLASSES to work + Imran Rashid + 2015-03-17 12:03:54 -0500 + Commit: ac0e7cc, github.com/apache/spark/pull/5071 + + [SPARK-6313] Add config option to disable file locks/fetchFile cache to ... + nemccarthy + 2015-03-17 09:33:11 -0700 + Commit: febb123, github.com/apache/spark/pull/5036 + + [SPARK-3266] Use intermediate abstract classes to fix type erasure issues in Java APIs + Josh Rosen + 2015-03-17 09:18:57 -0700 + Commit: 29e39e1, github.com/apache/spark/pull/5050 + + [SPARK-6331] Load new master URL if present when recovering streaming context from checkpoint + Tathagata Das + 2015-03-17 05:31:27 -0700 + Commit: 95f8d1c, github.com/apache/spark/pull/5024 + + [SQL][docs][minor] Fixed sample code in SQLContext scaladoc + Lomig Mégard + 2015-03-16 23:52:42 -0700 + Commit: 426816b, github.com/apache/spark/pull/5051 + + [SPARK-6299][CORE] ClassNotFoundException in standalone mode when running groupByKey with class defined in REPL + Kevin (Sangwoo) Kim + 2015-03-16 23:49:23 -0700 + Commit: 5c16ced, github.com/apache/spark/pull/5046 + + [SPARK-6077] Remove streaming tab while stopping StreamingContext + lisurprise + 2015-03-16 13:10:32 -0700 + Commit: 47cce98, github.com/apache/spark/pull/4828 + + [SPARK-6330] Fix filesystem bug in newParquet relation + Volodymyr Lyubinets + 2015-03-16 12:13:18 -0700 + Commit: 67fa6d1, github.com/apache/spark/pull/5020 + + SPARK-6245 [SQL] jsonRDD() of empty RDD results in exception + Sean Owen + 2015-03-11 14:09:09 +0000 + Commit: 684ff24, github.com/apache/spark/pull/4971 + + [SPARK-6300][Spark Core] sc.addFile(path) does not support the relative path. + DoingDone9 <799203320@qq.com> + 2015-03-16 12:27:15 +0000 + Commit: 724aab4, github.com/apache/spark/pull/4993 + + [SPARK-3619] Part 2. Upgrade to Mesos 0.21 to work around MESOS-1688 + Jongyoul Lee + 2015-03-15 15:46:55 +0000 + Commit: 43fcab0, github.com/apache/spark/pull/4361 + + [SPARK-6210] [SQL] use prettyString as column name in agg() + Davies Liu + 2015-03-14 00:43:33 -0700 + Commit: ad47563, github.com/apache/spark/pull/5006 + + [SPARK-6275][Documentation]Miss toDF() function in docs/sql-programming-guide.md + zzcclp + 2015-03-12 15:07:15 +0000 + Commit: 3012781, github.com/apache/spark/pull/4977 + + [SPARK-6133] Make sc.stop() idempotent + Andrew Or + 2015-03-03 15:09:57 -0800 + Commit: a08588c, github.com/apache/spark/pull/4871 + + [SPARK-6132][HOTFIX] ContextCleaner InterruptedException should be quiet + Andrew Or + 2015-03-03 20:49:45 -0800 + Commit: 338bea7, github.com/apache/spark/pull/4882 + + [SPARK-6132] ContextCleaner race condition across SparkContexts + Andrew Or + 2015-03-03 13:44:05 -0800 + Commit: 3cdc8a3, github.com/apache/spark/pull/4869 + + [SPARK-6087][CORE] Provide actionable exception if Kryo buffer is not large enough + Lev Khomich + 2015-03-10 10:55:42 +0000 + Commit: 9846790, github.com/apache/spark/pull/4947 + + [SPARK-6036][CORE] avoid race condition between eventlogListener and akka actor system + Zhang, Liye + 2015-02-26 23:11:43 -0800 + Commit: f81611d, github.com/apache/spark/pull/4785 + + SPARK-4044 [CORE] Thriftserver fails to start when JAVA_HOME points to JRE instead of JDK + Sean Owen + 2015-03-13 17:59:31 +0000 + Commit: 4aa4132, github.com/apache/spark/pull/4981 + + SPARK-4300 [CORE] Race condition during SparkWorker shutdown + Sean Owen + 2015-02-26 14:08:56 -0800 + Commit: a3493eb, github.com/apache/spark/pull/4787 + + [SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect() + Davies Liu + 2015-03-09 16:24:06 -0700 + Commit: 170af49, github.com/apache/spark/pull/4923 + + SPARK-4704 [CORE] SparkSubmitDriverBootstrap doesn't flush output + Sean Owen + 2015-02-26 12:56:54 -0800 + Commit: dbee7e1, github.com/apache/spark/pull/4788 + + [SPARK-6278][MLLIB] Mention the change of objective in linear regression + Xiangrui Meng + 2015-03-13 10:27:28 -0700 + Commit: 214f681, github.com/apache/spark/pull/4978 + + [SPARK-5310] [SQL] [DOC] Parquet section for the SQL programming guide + Cheng Lian + 2015-03-13 21:34:50 +0800 + Commit: dc287f3, github.com/apache/spark/pull/5001 + + [mllib] [python] Add LassoModel to __all__ in regression.py + Joseph K. Bradley + 2015-03-12 16:46:29 -0700 + Commit: 23069bd, github.com/apache/spark/pull/4970 + + [SPARK-6294] fix hang when call take() in JVM on PythonRDD + Davies Liu + 2015-03-12 01:34:38 -0700 + Commit: 850e694, github.com/apache/spark/pull/4987 + + [SPARK-6296] [SQL] Added equals to Column + Volodymyr Lyubinets + 2015-03-12 00:55:26 -0700 + Commit: d9e141c, github.com/apache/spark/pull/4988 + + [SPARK-6128][Streaming][Documentation] Updates to Spark Streaming Programming Guide + Tathagata Das + 2015-03-11 18:48:21 -0700 + Commit: bdc4682, github.com/apache/spark/pull/4956 + + [SPARK-6274][Streaming][Examples] Added examples streaming + sql examples. + Tathagata Das + 2015-03-11 11:19:51 -0700 + Commit: ac61466, github.com/apache/spark/pull/4975 + + [SPARK-5183][SQL] Update SQL Docs with JDBC and Migration Guide + Michael Armbrust + 2015-03-10 18:13:09 -0700 + Commit: edbcb6f, github.com/apache/spark/pull/4958 + + Minor doc: Remove the extra blank line in data types javadoc. + Reynold Xin + 2015-03-10 17:25:04 -0700 + Commit: 7295192, github.com/apache/spark/pull/4955 + + [SPARK-5310][Doc] Update SQL Programming Guide to include DataFrames. + Reynold Xin + 2015-03-09 16:16:16 -0700 + Commit: bc53d3d, github.com/apache/spark/pull/4954 + + [Docs] Replace references to SchemaRDD with DataFrame + Reynold Xin + 2015-03-09 13:29:19 -0700 + Commit: 5e58f76, github.com/apache/spark/pull/4952 + + Preparing development version 1.3.1-SNAPSHOT + Patrick Wendell + 2015-03-05 23:02:08 +0000 + Commit: c152f9a + + +Release 1.3.0 + + [SQL] Make Strategies a public developer API + Michael Armbrust + 2015-03-05 14:50:25 -0800 + Commit: 556e0de, github.com/apache/spark/pull/4920 + + [SPARK-6163][SQL] jsonFile should be backed by the data source API + Yin Huai + 2015-03-05 14:49:44 -0800 + Commit: 083fed5, github.com/apache/spark/pull/4896 + + [SPARK-6145][SQL] fix ORDER BY on nested fields + Wenchen Fan , Michael Armbrust + 2015-03-05 14:49:01 -0800 + Commit: e358f55, github.com/apache/spark/pull/4918 + + [SPARK-6175] Fix standalone executor log links when ephemeral ports or SPARK_PUBLIC_DNS are used + Josh Rosen + 2015-03-05 12:04:00 -0800 + Commit: 988b498, github.com/apache/spark/pull/4903 + + SPARK-6182 [BUILD] spark-parent pom needs to be published for both 2.10 and 2.11 + Sean Owen + 2015-03-05 11:31:48 -0800 + Commit: ae315d2, github.com/apache/spark/pull/4912 + + Revert "[SPARK-6153] [SQL] promote guava dep for hive-thriftserver" + Cheng Lian + 2015-03-05 17:58:18 +0800 + Commit: f8205d3 + + [SPARK-6153] [SQL] promote guava dep for hive-thriftserver + Daoyuan Wang + 2015-03-05 16:35:17 +0800 + Commit: b92d925, github.com/apache/spark/pull/4884 + + Updating CHANGES file + Patrick Wendell + 2015-03-04 21:19:49 -0800 + Commit: 87eac3c + + SPARK-5143 [BUILD] [WIP] spark-network-yarn 2.11 depends on spark-network-shuffle 2.10 + Sean Owen + 2015-03-04 21:00:51 -0800 + Commit: f509159, github.com/apache/spark/pull/4876 + + [SPARK-6149] [SQL] [Build] Excludes Guava 15 referenced by jackson-module-scala_2.10 + Cheng Lian + 2015-03-04 20:52:58 -0800 + Commit: a0aa24a, github.com/apache/spark/pull/4890 + + [SPARK-6144] [core] Fix addFile when source files are on "hdfs:" + Marcelo Vanzin , trystanleftwich + 2015-03-04 12:58:39 -0800 + Commit: 3fc74f4, github.com/apache/spark/pull/4894 + + [SPARK-6134][SQL] Fix wrong datatype for casting FloatType and default LongType value in defaultPrimitive + Liang-Chi Hsieh + 2015-03-04 20:23:43 +0800 + Commit: bfa4e31, github.com/apache/spark/pull/4870 + + [SPARK-6136] [SQL] Removed JDBC integration tests which depends on docker-client + Cheng Lian + 2015-03-04 19:39:02 +0800 + Commit: 035243d, github.com/apache/spark/pull/4872 + + [SPARK-6141][MLlib] Upgrade Breeze from 0.10 to 0.11 to fix convergence bug + Xiangrui Meng , DB Tsai , DB Tsai + 2015-03-03 23:52:02 -0800 + Commit: 9f24977, github.com/apache/spark/pull/4879 + + [SPARK-5949] HighlyCompressedMapStatus needs more classes registered w/ kryo + Imran Rashid + 2015-03-03 15:33:19 -0800 + Commit: 9a0b75c, github.com/apache/spark/pull/4877 + + SPARK-1911 [DOCS] Warn users if their assembly jars are not built with Java 6 + Sean Owen + 2015-03-03 13:40:11 -0800 + Commit: 8446ad0, github.com/apache/spark/pull/4874 + + Revert "[SPARK-5423][Core] Cleanup resources in DiskMapIterator.finalize to ensure deleting the temp file" + Andrew Or + 2015-03-03 13:04:15 -0800 + Commit: ee4929d + + Adding CHANGES.txt for Spark 1.3 + Patrick Wendell + 2015-03-03 02:19:19 -0800 + Commit: ce7158c + + BUILD: Minor tweaks to internal build scripts + Patrick Wendell + 2015-03-03 00:38:12 -0800 + Commit: ae60eb9 + + HOTFIX: Bump HBase version in MapR profiles. + Patrick Wendell + 2015-03-03 01:38:07 -0800 + Commit: 1aa8461 + + [SPARK-5537][MLlib][Docs] Add user guide for multinomial logistic regression + DB Tsai + 2015-03-02 22:37:12 -0800 + Commit: 841d2a2, github.com/apache/spark/pull/4866 + + [SPARK-6120] [mllib] Warnings about memory in tree, ensemble model save + Joseph K. Bradley + 2015-03-02 22:33:51 -0800 + Commit: 81648a7, github.com/apache/spark/pull/4864 + + [SPARK-6097][MLLIB] Support tree model save/load in PySpark/MLlib + Xiangrui Meng + 2015-03-02 22:27:01 -0800 + Commit: 62c53be, github.com/apache/spark/pull/4854 + + [SPARK-5310][SQL] Fixes to Docs and Datasources API + Reynold Xin , Michael Armbrust + 2015-03-02 22:14:08 -0800 + Commit: 4e6e008, github.com/apache/spark/pull/4868 + + [SPARK-5950][SQL]Insert array into a metastore table saved as parquet should work when using datasource api + Yin Huai + 2015-03-02 19:31:55 -0800 + Commit: 1b490e9, github.com/apache/spark/pull/4826 + + [SPARK-6127][Streaming][Docs] Add Kafka to Python api docs + Tathagata Das + 2015-03-02 18:40:46 -0800 + Commit: ffd0591, github.com/apache/spark/pull/4860 + + [SPARK-5537] Add user guide for multinomial logistic regression + Xiangrui Meng , DB Tsai + 2015-03-02 18:10:50 -0800 + Commit: 11389f0, github.com/apache/spark/pull/4801 + + [SPARK-6121][SQL][MLLIB] simpleString for UDT + Xiangrui Meng + 2015-03-02 17:14:34 -0800 + Commit: 1b8ab57, github.com/apache/spark/pull/4858 + + [SPARK-6048] SparkConf should not translate deprecated configs on set + Andrew Or + 2015-03-02 16:36:42 -0800 + Commit: ea69cf2, github.com/apache/spark/pull/4799 + + [SPARK-6066] Make event log format easier to parse + Andrew Or + 2015-03-02 16:34:32 -0800 + Commit: 8100b79, github.com/apache/spark/pull/4821 + + [SPARK-6082] [SQL] Provides better error message for malformed rows when caching tables + Cheng Lian + 2015-03-02 16:18:00 -0800 + Commit: 866f281, github.com/apache/spark/pull/4842 + + [SPARK-6114][SQL] Avoid metastore conversions before plan is resolved + Michael Armbrust + 2015-03-02 16:10:54 -0800 + Commit: 3899c7c, github.com/apache/spark/pull/4855 + + [SPARK-6050] [yarn] Relax matching of vcore count in received containers. + Marcelo Vanzin + 2015-03-02 16:41:43 -0600 + Commit: 650d1e7, github.com/apache/spark/pull/4818 + + [SPARK-6040][SQL] Fix the percent bug in tablesample + q00251598 + 2015-03-02 13:16:29 -0800 + Commit: a83b9bb, github.com/apache/spark/pull/4789 + + [Minor] Fix doc typo for describing primitiveTerm effectiveness condition + Liang-Chi Hsieh + 2015-03-02 13:11:17 -0800 + Commit: f92876a, github.com/apache/spark/pull/4762 + + SPARK-5390 [DOCS] Encourage users to post on Stack Overflow in Community Docs + Sean Owen + 2015-03-02 21:10:08 +0000 + Commit: 58e7198, github.com/apache/spark/pull/4843 + + [DOCS] Refactored Dataframe join comment to use correct parameter ordering + Paul Power + 2015-03-02 13:08:47 -0800 + Commit: 54ac243, github.com/apache/spark/pull/4847 + + [SPARK-6080] [PySpark] correct LogisticRegressionWithLBFGS regType parameter for pyspark + Yanbo Liang + 2015-03-02 10:17:24 -0800 + Commit: 4ffaf85, github.com/apache/spark/pull/4831 + + [SPARK-5741][SQL] Support the path contains comma in HiveContext + q00251598 + 2015-03-02 10:13:11 -0800 + Commit: f476108, github.com/apache/spark/pull/4532 + + [SPARK-6111] Fixed usage string in documentation. + Kenneth Myers + 2015-03-02 17:25:24 +0000 + Commit: b2b7f01, github.com/apache/spark/pull/4852 + + [SPARK-6052][SQL]In JSON schema inference, we should always set containsNull of an ArrayType to true + Yin Huai + 2015-03-02 23:18:07 +0800 + Commit: a3fef2c, github.com/apache/spark/pull/4806 + + [SPARK-6073][SQL] Need to refresh metastore cache after append data in CreateMetastoreDataSourceAsSelect + Yin Huai + 2015-03-02 22:42:18 +0800 + Commit: c59871c, github.com/apache/spark/pull/4824 + + [Streaming][Minor]Fix some error docs in streaming examples + Saisai Shao + 2015-03-02 08:49:19 +0000 + Commit: 1fe677a, github.com/apache/spark/pull/4837 + + [SPARK-6083] [MLLib] [DOC] Make Python API example consistent in NaiveBayes + MechCoder + 2015-03-01 16:28:15 -0800 + Commit: 6a2fc85, github.com/apache/spark/pull/4834 + + [SPARK-6053][MLLIB] support save/load in PySpark's ALS + Xiangrui Meng + 2015-03-01 16:26:57 -0800 + Commit: b570d98, github.com/apache/spark/pull/4811 + + [SPARK-6074] [sql] Package pyspark sql bindings. + Marcelo Vanzin + 2015-03-01 11:05:10 +0000 + Commit: bb16618, github.com/apache/spark/pull/4822 + + SPARK-5984: Fix TimSort bug causes ArrayOutOfBoundsException + Evan Yu + 2015-02-28 18:55:34 -0800 + Commit: 317694c, github.com/apache/spark/pull/4804 + + [SPARK-5775] [SQL] BugFix: GenericRow cannot be cast to SpecificMutableRow when nested data and partitioned table + Cheng Lian , Cheng Lian , Yin Huai + 2015-02-28 21:15:43 +0800 + Commit: aa39460, github.com/apache/spark/pull/4792 + + [SPARK-5979][SPARK-6032] Smaller safer --packages fix + Burak Yavuz + 2015-02-27 22:59:35 -0800 + Commit: 5a55c96, github.com/apache/spark/pull/4802 + + [SPARK-6070] [yarn] Remove unneeded classes from shuffle service jar. + Marcelo Vanzin + 2015-02-27 22:44:11 -0800 + Commit: 1747e0a, github.com/apache/spark/pull/4820 + + [SPARK-6055] [PySpark] fix incorrect __eq__ of DataType + Davies Liu + 2015-02-27 20:07:17 -0800 + Commit: 49f2187, github.com/apache/spark/pull/4808 + + [SPARK-5751] [SQL] Sets SPARK_HOME as SPARK_PID_DIR when running Thrift server test suites + Cheng Lian + 2015-02-28 08:41:49 +0800 + Commit: 5d19cf0, github.com/apache/spark/pull/4758 + + [Streaming][Minor] Remove useless type signature of Java Kafka direct stream API + Saisai Shao + 2015-02-27 13:01:42 -0800 + Commit: ceebe3c, github.com/apache/spark/pull/4817 + + [SPARK-4587] [mllib] [docs] Fixed save,load calls in ML guide examples + Joseph K. Bradley + 2015-02-27 13:00:36 -0800 + Commit: 117e10c, github.com/apache/spark/pull/4816 + + [SPARK-6058][Yarn] Log the user class exception in ApplicationMaster + zsxwing + 2015-02-27 13:31:46 +0000 + Commit: bff8088, github.com/apache/spark/pull/4813 + + fix spark-6033, clarify the spark.worker.cleanup behavior in standalone mode + 许鹏 + 2015-02-26 23:05:56 -0800 + Commit: b8db84c, github.com/apache/spark/pull/4803 + + SPARK-2168 [Spark core] Use relative URIs for the app links in the History Server. + Lukasz Jastrzebski + 2015-02-26 22:38:06 -0800 + Commit: 485b919, github.com/apache/spark/pull/4778 + + [SPARK-6024][SQL] When a data source table has too many columns, it's schema cannot be stored in metastore. + Yin Huai + 2015-02-26 20:46:05 -0800 + Commit: 6200f07, github.com/apache/spark/pull/4795 + + [SPARK-6037][SQL] Avoiding duplicate Parquet schema merging + Liang-Chi Hsieh + 2015-02-27 11:06:47 +0800 + Commit: 25a109e, github.com/apache/spark/pull/4786 + + SPARK-4579 [WEBUI] Scheduling Delay appears negative + Sean Owen + 2015-02-26 17:35:09 -0800 + Commit: b83a93e, github.com/apache/spark/pull/4796 + + [SPARK-5951][YARN] Remove unreachable driver memory properties in yarn client mode + mohit.goyal + 2015-02-26 14:27:47 -0800 + Commit: 5b426cb, github.com/apache/spark/pull/4730 + + Add a note for context termination for History server on Yarn + moussa taifi + 2015-02-26 14:19:43 -0800 + Commit: 297c3ef, github.com/apache/spark/pull/4721 + + [SPARK-6018] [YARN] NoSuchMethodError in Spark app is swallowed by YARN AM + Cheolsoo Park + 2015-02-26 13:53:49 -0800 + Commit: fe79674, github.com/apache/spark/pull/4773 + + [SPARK-6027][SPARK-5546] Fixed --jar and --packages not working for KafkaUtils and improved error message + Tathagata Das + 2015-02-26 13:46:07 -0800 + Commit: 731a997, github.com/apache/spark/pull/4779 + + Modify default value description for spark.scheduler.minRegisteredResourcesRatio on docs. + Li Zhihui + 2015-02-26 13:07:07 -0800 + Commit: 62652dc, github.com/apache/spark/pull/4781 + + [SPARK-5363] Fix bug in PythonRDD: remove() inside iterator is not safe + Davies Liu + 2015-02-26 11:54:17 -0800 + Commit: 5d309ad, github.com/apache/spark/pull/4776 + + [SPARK-6015] fix links to source code in Python API docs + Davies Liu + 2015-02-26 10:45:29 -0800 + Commit: dafb3d2, github.com/apache/spark/pull/4772 + + [SPARK-6007][SQL] Add numRows param in DataFrame.show() + Jacky Li + 2015-02-26 10:40:58 -0800 + Commit: 7c779d8, github.com/apache/spark/pull/4767 + + [SPARK-6016][SQL] Cannot read the parquet table after overwriting the existing table when spark.sql.parquet.cacheMetadata=true + Yin Huai + 2015-02-27 01:01:32 +0800 + Commit: b5c5e93, github.com/apache/spark/pull/4775 + + [SPARK-6023][SQL] ParquetConversions fails to replace the destination MetastoreRelation of an InsertIntoTable node to ParquetRelation2 + Yin Huai + 2015-02-26 22:39:49 +0800 + Commit: e0f5fb0, github.com/apache/spark/pull/4782 + + [SPARK-5976][MLLIB] Add partitioner to factors returned by ALS + Xiangrui Meng + 2015-02-25 23:43:29 -0800 + Commit: a51d9db, github.com/apache/spark/pull/4748 + + [SPARK-1182][Docs] Sort the configuration parameters in configuration.md + Brennon York + 2015-02-25 16:12:56 -0800 + Commit: 56fa38a, github.com/apache/spark/pull/3863 + + [SPARK-5724] fix the misconfiguration in AkkaUtils + CodingCat + 2015-02-23 11:29:25 +0000 + Commit: b32a653, github.com/apache/spark/pull/4512 + + [SPARK-5974] [SPARK-5980] [mllib] [python] [docs] Update ML guide with save/load, Python GBT + Joseph K. Bradley + 2015-02-25 16:13:17 -0800 + Commit: a1b4856, github.com/apache/spark/pull/4750 + + [SPARK-5926] [SQL] make DataFrame.explain leverage queryExecution.logical + Yanbo Liang + 2015-02-25 15:37:13 -0800 + Commit: 5bd4b49, github.com/apache/spark/pull/4707 + + [SPARK-5999][SQL] Remove duplicate Literal matching block + Liang-Chi Hsieh + 2015-02-25 15:22:33 -0800 + Commit: 6fff9b8, github.com/apache/spark/pull/4760 + + [SPARK-6010] [SQL] Merging compatible Parquet schemas before computing splits + Cheng Lian + 2015-02-25 15:15:22 -0800 + Commit: 016f1f8, github.com/apache/spark/pull/4768 + + [SPARK-5944] [PySpark] fix version in Python API docs + Davies Liu + 2015-02-25 15:13:34 -0800 + Commit: 9aca3c6, github.com/apache/spark/pull/4731 + + [SPARK-5982] Remove incorrect Local Read Time Metric + Kay Ousterhout + 2015-02-25 14:55:24 -0800 + Commit: 791df93, github.com/apache/spark/pull/4749 + + [SPARK-1955][GraphX]: VertexRDD can incorrectly assume index sharing + Brennon York + 2015-02-25 14:11:12 -0800 + Commit: 8073767, github.com/apache/spark/pull/4705 + + SPARK-5930 [DOCS] Documented default of spark.shuffle.io.retryWait is confusing + Sean Owen + 2015-02-25 12:20:44 -0800 + Commit: eaffc6e, github.com/apache/spark/pull/4769 + + [SPARK-5996][SQL] Fix specialized outbound conversions + Michael Armbrust + 2015-02-25 10:13:40 -0800 + Commit: fada683, github.com/apache/spark/pull/4757 + + [SPARK-5994] [SQL] Python DataFrame documentation fixes + Davies Liu + 2015-02-24 20:51:55 -0800 + Commit: 5c421e0, github.com/apache/spark/pull/4756 + + [SPARK-5286][SQL] SPARK-5286 followup + Yin Huai + 2015-02-24 19:51:36 -0800 + Commit: e7a748e, github.com/apache/spark/pull/4755 + + [SPARK-5993][Streaming][Build] Fix assembly jar location of kafka-assembly + Tathagata Das + 2015-02-24 19:10:37 -0800 + Commit: 1e94894, github.com/apache/spark/pull/4753 + + [SPARK-5985][SQL] DataFrame sortBy -> orderBy in Python. + Reynold Xin + 2015-02-24 18:59:23 -0800 + Commit: 5e233b2, github.com/apache/spark/pull/4752 + + [SPARK-5904][SQL] DataFrame Java API test suites. + Reynold Xin + 2015-02-24 18:51:41 -0800 + Commit: 78a1781, github.com/apache/spark/pull/4751 + + [SPARK-5751] [SQL] [WIP] Revamped HiveThriftServer2Suite for robustness + Cheng Lian + 2015-02-25 08:34:55 +0800 + Commit: 17ee246, github.com/apache/spark/pull/4720 + + [SPARK-5973] [PySpark] fix zip with two RDDs with AutoBatchedSerializer + Davies Liu + 2015-02-24 14:50:00 -0800 + Commit: 91bf0f8, github.com/apache/spark/pull/4745 + + [SPARK-5952][SQL] Lock when using hive metastore client + Michael Armbrust + 2015-02-24 13:39:29 -0800 + Commit: 641423d, github.com/apache/spark/pull/4746 + + [MLLIB] Change x_i to y_i in Variance's user guide + Xiangrui Meng + 2015-02-24 11:38:59 -0800 + Commit: a4ff445, github.com/apache/spark/pull/4740 + + [SPARK-5965] Standalone Worker UI displays {{USER_JAR}} + Andrew Or + 2015-02-24 11:08:07 -0800 + Commit: eaf7bf9, github.com/apache/spark/pull/4739 + + [Spark-5967] [UI] Correctly clean JobProgressListener.stageIdToActiveJobIds + Tathagata Das + 2015-02-24 11:02:47 -0800 + Commit: 28dd53b, github.com/apache/spark/pull/4741 + + [SPARK-5532][SQL] Repartition should not use external rdd representation + Michael Armbrust + 2015-02-24 10:52:18 -0800 + Commit: e46096b, github.com/apache/spark/pull/4738 + + [SPARK-5910][SQL] Support for as in selectExpr + Michael Armbrust + 2015-02-24 10:49:51 -0800 + Commit: ba5d60d, github.com/apache/spark/pull/4736 + + [SPARK-5968] [SQL] Suppresses ParquetOutputCommitter WARN logs + Cheng Lian + 2015-02-24 10:45:38 -0800 + Commit: 2b562b0, github.com/apache/spark/pull/4744 + + [SPARK-5958][MLLIB][DOC] update block matrix user guide + Xiangrui Meng + 2015-02-23 22:08:44 -0800 + Commit: dd42558, github.com/apache/spark/pull/4737 + + [SPARK-5873][SQL] Allow viewing of partially analyzed plans in queryExecution + Michael Armbrust + 2015-02-23 17:34:54 -0800 + Commit: 2d7786e, github.com/apache/spark/pull/4684 + + [SPARK-5935][SQL] Accept MapType in the schema provided to a JSON dataset. + Yin Huai , Yin Huai + 2015-02-23 17:16:34 -0800 + Commit: 33ccad2, github.com/apache/spark/pull/4710 + + [SPARK-5912] [docs] [mllib] Small fixes to ChiSqSelector docs + Joseph K. Bradley + 2015-02-23 16:15:57 -0800 + Commit: ae97040, github.com/apache/spark/pull/4732 + + [MLLIB] SPARK-5912 Programming guide for feature selection + Alexander Ulanov + 2015-02-23 12:09:40 -0800 + Commit: 8355773, github.com/apache/spark/pull/4709 + + [SPARK-5939][MLLib] make FPGrowth example app take parameters + Jacky Li + 2015-02-23 08:47:28 -0800 + Commit: 33b9084, github.com/apache/spark/pull/4714 + + [SPARK-5943][Streaming] Update the test to use new API to reduce the warning + Saisai Shao + 2015-02-23 11:27:27 +0000 + Commit: 67b7f79, github.com/apache/spark/pull/4722 + + [EXAMPLES] fix typo. + Makoto Fukuhara + 2015-02-23 09:24:33 +0000 + Commit: f172387, github.com/apache/spark/pull/4724 + + Revert "[SPARK-4808] Removing minimum number of elements read before spill check" + Andrew Or + 2015-02-22 09:44:52 -0800 + Commit: 4186dd3 + + SPARK-5669 [BUILD] Reverse exclusion of JBLAS libs for 1.3 + Sean Owen + 2015-02-22 09:09:06 +0000 + Commit: eed7389, github.com/apache/spark/pull/4715 + + [DataFrame] [Typo] Fix the typo + Cheng Hao + 2015-02-22 08:56:30 +0000 + Commit: 04d3b32, github.com/apache/spark/pull/4717 + + [DOCS] Fix typo in API for custom InputFormats based on the “new” MapReduce API + Alexander + 2015-02-22 08:53:05 +0000 + Commit: c5a5c6f, github.com/apache/spark/pull/4718 + + [SPARK-5937][YARN] Fix ClientSuite to set YARN mode, so that the correct class is used in t... + Hari Shreedharan + 2015-02-21 10:01:01 -0800 + Commit: 76e3e65, github.com/apache/spark/pull/4711 + + SPARK-5841 [CORE] [HOTFIX 2] Memory leak in DiskBlockManager + Nishkam Ravi , nishkamravi2 , nravi + 2015-02-21 09:59:28 -0800 + Commit: 932338e, github.com/apache/spark/pull/4690 + + [SPARK-5909][SQL] Add a clearCache command to Spark SQL's cache manager + Yin Huai + 2015-02-20 16:20:02 +0800 + Commit: b9a6c5c, github.com/apache/spark/pull/4694 + + [SPARK-5898] [SPARK-5896] [SQL] [PySpark] create DataFrame from pandas and tuple/list + Davies Liu + 2015-02-20 15:35:05 -0800 + Commit: 913562a, github.com/apache/spark/pull/4679 + + [SPARK-5867] [SPARK-5892] [doc] [ml] [mllib] Doc cleanups for 1.3 release + Joseph K. Bradley + 2015-02-20 02:31:32 -0800 + Commit: 8c12f31, github.com/apache/spark/pull/4675 + + [SPARK-4808] Removing minimum number of elements read before spill check + mcheah + 2015-02-19 18:09:22 -0800 + Commit: 0382dcc, github.com/apache/spark/pull/4420 + + [SPARK-5900][MLLIB] make PIC and FPGrowth Java-friendly + Xiangrui Meng + 2015-02-19 18:06:16 -0800 + Commit: ba941ce, github.com/apache/spark/pull/4695 + + SPARK-5570: No docs stating that `new SparkConf().set("spark.driver.memory", ...) will not work + Ilya Ganelin + 2015-02-19 15:50:58 -0800 + Commit: c5f3b9e, github.com/apache/spark/pull/4665 + + SPARK-4682 [CORE] Consolidate various 'Clock' classes + Sean Owen + 2015-02-19 15:35:23 -0800 + Commit: bd49e8b, github.com/apache/spark/pull/4514 + + [Spark-5889] Remove pid file after stopping service. + Zhan Zhang + 2015-02-19 23:13:02 +0000 + Commit: ff8976e, github.com/apache/spark/pull/4676 + + [SPARK-5902] [ml] Made PipelineStage.transformSchema public instead of private to ml + Joseph K. Bradley + 2015-02-19 12:46:27 -0800 + Commit: 0c494cf, github.com/apache/spark/pull/4682 + + [SPARK-5904][SQL] DataFrame API fixes. + Reynold Xin + 2015-02-19 12:09:44 -0800 + Commit: 55d91d9, github.com/apache/spark/pull/4686 + + [SPARK-5825] [Spark Submit] Remove the double checking instance name when stopping the service + Cheng Hao + 2015-02-19 12:07:51 -0800 + Commit: fe00eb6, github.com/apache/spark/pull/4611 + + [SPARK-5423][Core] Cleanup resources in DiskMapIterator.finalize to ensure deleting the temp file + zsxwing + 2015-02-19 18:37:31 +0000 + Commit: 25fae8e, github.com/apache/spark/pull/4219 + + [SPARK-5816] Add huge compatibility warning in DriverWrapper + Andrew Or + 2015-02-19 09:56:25 -0800 + Commit: f93d4d9, github.com/apache/spark/pull/4687 + + SPARK-5548: Fix for AkkaUtilsSuite failure - attempt 2 + Jacek Lewandowski + 2015-02-19 09:53:36 -0800 + Commit: fbcb949, github.com/apache/spark/pull/4653 + + [SPARK-5846] Correctly set job description and pool for SQL jobs + Kay Ousterhout + 2015-02-19 09:49:34 +0800 + Commit: 092b45f, github.com/apache/spark/pull/4630 + + [SPARK-5879][MLLIB] update PIC user guide and add a Java example + Xiangrui Meng + 2015-02-18 16:29:32 -0800 + Commit: a64f374, github.com/apache/spark/pull/4680 + + [SPARK-5722] [SQL] [PySpark] infer int as LongType + Davies Liu + 2015-02-18 14:17:04 -0800 + Commit: 470cba8, github.com/apache/spark/pull/4666 + + [SPARK-5840][SQL] HiveContext cannot be serialized due to tuple extraction + Reynold Xin + 2015-02-18 14:02:32 -0800 + Commit: b86e44c, github.com/apache/spark/pull/4628 + + [SPARK-5507] Added documentation for BlockMatrix + Burak Yavuz + 2015-02-18 10:11:08 -0800 + Commit: 56f8f29, github.com/apache/spark/pull/4664 + + [SPARK-5519][MLLIB] add user guide with example code for fp-growth + Xiangrui Meng + 2015-02-18 10:09:56 -0800 + Commit: 661fbd3, github.com/apache/spark/pull/4661 + + SPARK-5669 [BUILD] [HOTFIX] Spark assembly includes incompatibly licensed libgfortran, libgcc code via JBLAS + Sean Owen + 2015-02-18 14:41:44 +0000 + Commit: 9f256ce, github.com/apache/spark/pull/4673 + + SPARK-4610 addendum: [Minor] [MLlib] Minor doc fix in GBT classification example + MechCoder + 2015-02-18 10:13:28 +0000 + Commit: 3997e74, github.com/apache/spark/pull/4672 + + [SPARK-5878] fix DataFrame.repartition() in Python + Davies Liu + 2015-02-18 01:00:54 -0800 + Commit: aca7991, github.com/apache/spark/pull/4667 + + Avoid deprecation warnings in JDBCSuite. + Tor Myklebust + 2015-02-18 01:00:13 -0800 + Commit: 9a565b8, github.com/apache/spark/pull/4668 + + [Minor] [SQL] Cleans up DataFrame variable names and toDF() calls + Cheng Lian + 2015-02-17 23:36:20 -0800 + Commit: 2bd33ce, github.com/apache/spark/pull/4670 + + [SPARK-5731][Streaming][Test] Fix incorrect test in DirectKafkaStreamSuite + Tathagata Das + 2015-02-17 22:44:16 -0800 + Commit: f8f9a64, github.com/apache/spark/pull/4597 + + [SPARK-5723][SQL]Change the default file format to Parquet for CTAS statements. + Yin Huai + 2015-02-17 18:14:33 -0800 + Commit: 6e82c46, github.com/apache/spark/pull/4639 + + Preparing development version 1.3.1-SNAPSHOT + Patrick Wendell + 2015-02-18 01:52:06 +0000 + Commit: 2ab0ba0 + + Preparing Spark release v1.3.0-rc1 + Patrick Wendell + 2015-02-18 01:52:06 +0000 + Commit: f97b0d4 + + [SPARK-5875][SQL]logical.Project should not be resolved if it contains aggregates or generators + Yin Huai + 2015-02-17 17:50:39 -0800 + Commit: e8284b2, github.com/apache/spark/pull/4663 + + Revert "Preparing Spark release v1.3.0-snapshot1" + Patrick Wendell + 2015-02-17 17:48:47 -0800 + Commit: 7320605 + + Revert "Preparing development version 1.3.1-SNAPSHOT" + Patrick Wendell + 2015-02-17 17:48:43 -0800 + Commit: 932ae4d + + [SPARK-4454] Revert getOrElse() cleanup in DAGScheduler.getCacheLocs() + Josh Rosen + 2015-02-17 17:45:16 -0800 + Commit: 7e5e4d8 + + [SPARK-4454] Properly synchronize accesses to DAGScheduler cacheLocs map + Josh Rosen + 2015-02-17 17:39:58 -0800 + Commit: 07a401a, github.com/apache/spark/pull/4660 + + [SPARK-5811] Added documentation for maven coordinates and added Spark Packages support + Burak Yavuz , Davies Liu + 2015-02-17 17:15:43 -0800 + Commit: cb90584, github.com/apache/spark/pull/4662 + + [SPARK-5785] [PySpark] narrow dependency for cogroup/join in PySpark + Davies Liu + 2015-02-17 16:54:57 -0800 + Commit: 8120235, github.com/apache/spark/pull/4629 + + [SPARK-5852][SQL]Fail to convert a newly created empty metastore parquet table to a data source parquet table. + Yin Huai , Cheng Hao + 2015-02-17 15:47:59 -0800 + Commit: 07d8ef9, github.com/apache/spark/pull/4655 + + [SPARK-5872] [SQL] create a sqlCtx in pyspark shell + Davies Liu + 2015-02-17 15:44:37 -0800 + Commit: 0dba382, github.com/apache/spark/pull/4659 + + [SPARK-5871] output explain in Python + Davies Liu + 2015-02-17 13:48:38 -0800 + Commit: cb06160, github.com/apache/spark/pull/4658 + + [SPARK-4172] [PySpark] Progress API in Python + Davies Liu + 2015-02-17 13:36:43 -0800 + Commit: 35e23ff, github.com/apache/spark/pull/3027 + + [SPARK-5868][SQL] Fix python UDFs in HiveContext and checks in SQLContext + Michael Armbrust + 2015-02-17 13:23:45 -0800 + Commit: e65dc1f, github.com/apache/spark/pull/4657 + + [SQL] [Minor] Update the HiveContext Unittest + Cheng Hao + 2015-02-17 12:25:35 -0800 + Commit: 0135651, github.com/apache/spark/pull/4584 + + [Minor][SQL] Use same function to check path parameter in JSONRelation + Liang-Chi Hsieh + 2015-02-17 12:24:13 -0800 + Commit: d74d5e8, github.com/apache/spark/pull/4649 + + [SPARK-5862][SQL] Only transformUp the given plan once in HiveMetastoreCatalog + Liang-Chi Hsieh + 2015-02-17 12:23:18 -0800 + Commit: 62063b7, github.com/apache/spark/pull/4651 + + [Minor] fix typo in SQL document + CodingCat + 2015-02-17 12:16:52 -0800 + Commit: 5636c4a, github.com/apache/spark/pull/4656 + + [SPARK-5864] [PySpark] support .jar as python package + Davies Liu + 2015-02-17 12:05:06 -0800 + Commit: 71cf6e2, github.com/apache/spark/pull/4652 + + SPARK-5841 [CORE] [HOTFIX] Memory leak in DiskBlockManager + Sean Owen + 2015-02-17 19:40:06 +0000 + Commit: e64afcd, github.com/apache/spark/pull/4648 + + [SPARK-5661]function hasShutdownDeleteTachyonDir should use shutdownDeleteTachyonPaths to determine whether contains file + xukun 00228947 , viper-kun + 2015-02-17 18:59:41 +0000 + Commit: 420bc9b, github.com/apache/spark/pull/4418 + + [SPARK-5778] throw if nonexistent metrics config file provided + Ryan Williams + 2015-02-17 10:57:16 -0800 + Commit: 2bf2b56, github.com/apache/spark/pull/4571 + + [SPARK-5859] [PySpark] [SQL] fix DataFrame Python API + Davies Liu + 2015-02-17 10:22:48 -0800 + Commit: 4a581aa, github.com/apache/spark/pull/4645 + + [SPARK-5166][SPARK-5247][SPARK-5258][SQL] API Cleanup / Documentation + Michael Armbrust + 2015-02-17 10:21:17 -0800 + Commit: cd3d415, github.com/apache/spark/pull/4642 + + [SPARK-5858][MLLIB] Remove unnecessary first() call in GLM + Xiangrui Meng + 2015-02-17 10:17:45 -0800 + Commit: 97cb568, github.com/apache/spark/pull/4647 + + SPARK-5856: In Maven build script, launch Zinc with more memory + Patrick Wendell + 2015-02-17 10:10:01 -0800 + Commit: 8240629, github.com/apache/spark/pull/4643 + + Revert "[SPARK-5363] [PySpark] check ending mark in non-block way" + Josh Rosen + 2015-02-17 07:48:27 -0800 + Commit: aeb85cd + + [SPARK-5826][Streaming] Fix Configuration not serializable problem + jerryshao + 2015-02-17 10:45:18 +0000 + Commit: b8da5c3, github.com/apache/spark/pull/4612 + + HOTFIX: Style issue causing build break + Patrick Wendell + 2015-02-16 22:10:39 -0800 + Commit: e9241fa + + [SPARK-5802][MLLIB] cache transformed data in glm + Xiangrui Meng + 2015-02-16 22:09:04 -0800 + Commit: dfe0fa0, github.com/apache/spark/pull/4593 + + [SPARK-5853][SQL] Schema support in Row. + Reynold Xin + 2015-02-16 20:42:57 -0800 + Commit: d0701d9, github.com/apache/spark/pull/4640 + + SPARK-5850: Remove experimental label for Scala 2.11 and FlumePollingStream + Patrick Wendell + 2015-02-16 20:33:33 -0800 + Commit: c6a7069, github.com/apache/spark/pull/4638 + + [SPARK-5363] [PySpark] check ending mark in non-block way + Davies Liu + 2015-02-16 20:32:03 -0800 + Commit: baad6b3, github.com/apache/spark/pull/4601 + + [SQL] Various DataFrame doc changes. + Reynold Xin + 2015-02-16 19:00:30 -0800 + Commit: e355b54, github.com/apache/spark/pull/4636 + + [SPARK-5849] Handle more types of invalid JSON requests in SubmitRestProtocolMessage.parseAction + Josh Rosen + 2015-02-16 18:08:02 -0800 + Commit: 385a339, github.com/apache/spark/pull/4637 + + [SPARK-3340] Deprecate ADD_JARS and ADD_FILES + azagrebin + 2015-02-16 18:06:19 -0800 + Commit: d8c70fb, github.com/apache/spark/pull/4616 + + [SPARK-5788] [PySpark] capture the exception in python write thread + Davies Liu + 2015-02-16 17:57:14 -0800 + Commit: c2a9a61, github.com/apache/spark/pull/4577 + + SPARK-5848: tear down the ConsoleProgressBar timer + Matt Whelan + 2015-02-17 00:59:49 +0000 + Commit: 52994d8, github.com/apache/spark/pull/4635 + + [SPARK-4865][SQL]Include temporary tables in SHOW TABLES + Yin Huai + 2015-02-16 15:59:23 -0800 + Commit: 8a94bf7, github.com/apache/spark/pull/4618 + + [SQL] Optimize arithmetic and predicate operators + kai + 2015-02-16 15:58:05 -0800 + Commit: 639a3c2, github.com/apache/spark/pull/4472 + + [SPARK-5839][SQL]HiveMetastoreCatalog does not recognize table names and aliases of data source tables. + Yin Huai + 2015-02-16 15:54:01 -0800 + Commit: a15a0a0, github.com/apache/spark/pull/4626 + + [SPARK-5746][SQL] Check invalid cases for the write path of data source API + Yin Huai + 2015-02-16 15:51:59 -0800 + Commit: 4198654, github.com/apache/spark/pull/4617 + + HOTFIX: Break in Jekyll build from #4589 + Patrick Wendell + 2015-02-16 15:43:56 -0800 + Commit: ad8fd4f + + [SPARK-2313] Use socket to communicate GatewayServer port back to Python driver + Josh Rosen + 2015-02-16 15:25:11 -0800 + Commit: b70b8ba, github.com/apache/spark/pull/3424. + + SPARK-5357: Update commons-codec version to 1.10 (current) + Matt Whelan + 2015-02-16 23:05:34 +0000 + Commit: 8c45619, github.com/apache/spark/pull/4153 + + SPARK-5841: remove DiskBlockManager shutdown hook on stop + Matt Whelan + 2015-02-16 22:54:32 +0000 + Commit: dd977df, github.com/apache/spark/pull/4627 + + [SPARK-5833] [SQL] Adds REFRESH TABLE command + Cheng Lian + 2015-02-16 12:52:05 -0800 + Commit: 864d77e, github.com/apache/spark/pull/4624 + + [SPARK-5296] [SQL] Add more filter types for data sources API + Cheng Lian + 2015-02-16 12:48:55 -0800 + Commit: 363a9a7, github.com/apache/spark/pull/4623 + + [SQL] Add fetched row count in SparkSQLCLIDriver + OopsOutOfMemory + 2015-02-16 12:34:09 -0800 + Commit: 0368494, github.com/apache/spark/pull/4604 + + [SQL] Initial support for reporting location of error in sql string + Michael Armbrust + 2015-02-16 12:32:56 -0800 + Commit: 63fa123, github.com/apache/spark/pull/4587 + + [SPARK-5824] [SQL] add null format in ctas and set default col comment to null + Daoyuan Wang + 2015-02-16 12:31:36 -0800 + Commit: c2eaaea, github.com/apache/spark/pull/4609 + + [SQL] [Minor] Update the SpecificMutableRow.copy + Cheng Hao + 2015-02-16 12:21:08 -0800 + Commit: 1a88955, github.com/apache/spark/pull/4619 + + SPARK-5795 [STREAMING] api.java.JavaPairDStream.saveAsNewAPIHadoopFiles may not friendly to java + Sean Owen + 2015-02-16 19:32:31 +0000 + Commit: fef2267, github.com/apache/spark/pull/4608 + + [SPARK-5799][SQL] Compute aggregation function on specified numeric columns + Liang-Chi Hsieh + 2015-02-16 10:06:11 -0800 + Commit: 0165e9d, github.com/apache/spark/pull/4592 + + [SPARK-4553] [SPARK-5767] [SQL] Wires Parquet data source with the newly introduced write support for data source API + Cheng Lian + 2015-02-16 01:38:31 -0800 + Commit: 78f7edb, github.com/apache/spark/pull/4563 + + [Minor] [SQL] Renames stringRddToDataFrame to stringRddToDataFrameHolder for consistency + Cheng Lian + 2015-02-16 01:33:37 -0800 + Commit: 066301c, github.com/apache/spark/pull/4613 + + [Ml] SPARK-5804 Explicitly manage cache in Crossvalidator k-fold loop + Peter Rudenko + 2015-02-16 00:07:23 -0800 + Commit: 0d93205, github.com/apache/spark/pull/4595 + + [Ml] SPARK-5796 Don't transform data on a last estimator in Pipeline + Peter Rudenko + 2015-02-15 20:51:32 -0800 + Commit: 9cf7d70, github.com/apache/spark/pull/4590 + + SPARK-5815 [MLLIB] Deprecate SVDPlusPlus APIs that expose DoubleMatrix from JBLAS + Sean Owen + 2015-02-15 20:41:27 -0800 + Commit: db3c539, github.com/apache/spark/pull/4614 + + [SPARK-5769] Set params in constructors and in setParams in Python ML pipelines + Xiangrui Meng + 2015-02-15 20:29:26 -0800 + Commit: d710991, github.com/apache/spark/pull/4564 + + SPARK-5669 [BUILD] Spark assembly includes incompatibly licensed libgfortran, libgcc code via JBLAS + Sean Owen + 2015-02-15 09:15:48 -0800 + Commit: 4e099d7, github.com/apache/spark/pull/4453 + + [MLLIB][SPARK-5502] User guide for isotonic regression + martinzapletal + 2015-02-15 09:10:03 -0800 + Commit: d96e188, github.com/apache/spark/pull/4536 + + [HOTFIX] Ignore DirectKafkaStreamSuite. + Patrick Wendell + 2015-02-13 12:43:53 -0800 + Commit: 70ebad4 + + [SPARK-5827][SQL] Add missing import in the example of SqlContext + Takeshi Yamamuro + 2015-02-15 14:42:20 +0000 + Commit: 9c1c70d, github.com/apache/spark/pull/4615 + + SPARK-5822 [BUILD] cannot import src/main/scala & src/test/scala into eclipse as source folder + gli + 2015-02-14 20:43:27 +0000 + Commit: f87f3b7, github.com/apache/spark/pull/4531 + + Revise formatting of previous commit f80e2629bb74bc62960c61ff313f7e7802d61319 + Sean Owen + 2015-02-14 20:12:29 +0000 + Commit: 1945fcf + + [SPARK-5800] Streaming Docs. Change linked files according the selected language + gasparms + 2015-02-14 20:10:29 +0000 + Commit: e99e170, github.com/apache/spark/pull/4589 + + [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames + Reynold Xin , Davies Liu + 2015-02-13 23:03:22 -0800 + Commit: ba91bf5, github.com/apache/spark/pull/4556 + + SPARK-3290 [GRAPHX] No unpersist callls in SVDPlusPlus + Sean Owen + 2015-02-13 20:12:52 -0800 + Commit: db57479, github.com/apache/spark/pull/4234 + + [SPARK-5227] [SPARK-5679] Disable FileSystem cache in WholeTextFileRecordReaderSuite + Josh Rosen + 2015-02-13 17:45:31 -0800 + Commit: 152147f, github.com/apache/spark/pull/4599 + + [SPARK-5730][ML] add doc groups to spark.ml components + Xiangrui Meng + 2015-02-13 16:45:59 -0800 + Commit: fccd38d, github.com/apache/spark/pull/4600 + + [SPARK-5803][MLLIB] use ArrayBuilder to build primitive arrays + Xiangrui Meng + 2015-02-13 16:43:49 -0800 + Commit: 356b798, github.com/apache/spark/pull/4594 + + [SPARK-5806] re-organize sections in mllib-clustering.md + Xiangrui Meng + 2015-02-13 15:09:27 -0800 + Commit: 9658763, github.com/apache/spark/pull/4598 + + [SPARK-5789][SQL]Throw a better error message if JsonRDD.parseJson encounters unrecoverable parsing errors. + Yin Huai + 2015-02-13 13:51:06 -0800 + Commit: d9d0250, github.com/apache/spark/pull/4582 + + [SPARK-5642] [SQL] Apply column pruning on unused aggregation fields + Daoyuan Wang , Michael Armbrust + 2015-02-13 13:46:50 -0800 + Commit: efffc2e, github.com/apache/spark/pull/4415 + + [HOTFIX] Fix build break in MesosSchedulerBackendSuite + Andrew Or + 2015-02-13 13:10:29 -0800 + Commit: 4160371 + + SPARK-5805 Fixed the type error in documentation. + Emre Sevinç + 2015-02-13 12:31:27 -0800 + Commit: ad73189, github.com/apache/spark/pull/4596 + + [SPARK-5735] Replace uses of EasyMock with Mockito + Josh Rosen + 2015-02-13 09:53:57 -0800 + Commit: cc9eec1, github.com/apache/spark/pull/4578 + + [SPARK-5783] Better eventlog-parsing error messages + Ryan Williams + 2015-02-13 09:47:26 -0800 + Commit: e5690a5, github.com/apache/spark/pull/4573 + + [SPARK-5503][MLLIB] Example code for Power Iteration Clustering + sboeschhuawei + 2015-02-13 09:45:57 -0800 + Commit: 5e63942, github.com/apache/spark/pull/4495 + + [SPARK-5732][CORE]:Add an option to print the spark version in spark script. + uncleGen , genmao.ygm + 2015-02-13 09:43:10 -0800 + Commit: 5c883df, github.com/apache/spark/pull/4522 + + [SPARK-4832][Deploy]some other processes might take the daemon pid + WangTaoTheTonic , WangTaoTheTonic + 2015-02-13 10:27:23 +0000 + Commit: 1255e83, github.com/apache/spark/pull/3683 + + [SQL] Fix docs of SQLContext.tables + Yin Huai + 2015-02-12 20:37:55 -0800 + Commit: a8f560c, github.com/apache/spark/pull/4579 + + [SPARK-3365][SQL]Wrong schema generated for List type + tianyi + 2015-02-12 22:18:39 -0800 + Commit: b9f332a, github.com/apache/spark/pull/4581 + + [SPARK-3299][SQL]Public API in SQLContext to list tables + Yin Huai + 2015-02-12 18:08:01 -0800 + Commit: edbac17, github.com/apache/spark/pull/4547 + + [SQL] Move SaveMode to SQL package. + Yin Huai + 2015-02-12 15:32:17 -0800 + Commit: 925fd84, github.com/apache/spark/pull/4542 + + [SPARK-5335] Fix deletion of security groups within a VPC + Vladimir Grigor , Vladimir Grigor + 2015-02-12 23:26:24 +0000 + Commit: 5c9db4e, github.com/apache/spark/pull/4122 + + [SPARK-5755] [SQL] remove unnecessary Add + Daoyuan Wang + 2015-02-12 15:22:07 -0800 + Commit: f7103b3, github.com/apache/spark/pull/4551 + + [SPARK-5573][SQL] Add explode to dataframes + Michael Armbrust + 2015-02-12 15:19:19 -0800 + Commit: c7eb9ee, github.com/apache/spark/pull/4546 + + [SPARK-5758][SQL] Use LongType as the default type for integers in JSON schema inference. + Yin Huai + 2015-02-12 15:17:25 -0800 + Commit: b0c79da, github.com/apache/spark/pull/4544 + + [SPARK-5780] [PySpark] Mute the logging during unit tests + Davies Liu + 2015-02-12 14:54:38 -0800 + Commit: bf0d15c, github.com/apache/spark/pull/4572 + + SPARK-5747: Fix wordsplitting bugs in make-distribution.sh + David Y. Ross + 2015-02-12 14:52:38 -0800 + Commit: 11a0d5b, github.com/apache/spark/pull/4540 + + [SPARK-5759][Yarn]ExecutorRunnable should catch YarnException while NMClient start contain... + lianhuiwang + 2015-02-12 14:50:16 -0800 + Commit: 02d5b32, github.com/apache/spark/pull/4554 + + [SPARK-5760][SPARK-5761] Fix standalone rest protocol corner cases + revamp tests + Andrew Or + 2015-02-12 14:47:52 -0800 + Commit: 11d1080, github.com/apache/spark/pull/4557 + + [SPARK-5762] Fix shuffle write time for sort-based shuffle + Kay Ousterhout + 2015-02-12 14:46:37 -0800 + Commit: 0040fc5, github.com/apache/spark/pull/4559 + + [SPARK-5765][Examples]Fixed word split problem in run-example and compute-classpath + Venkata Ramana G , Venkata Ramana Gollamudi + 2015-02-12 14:44:21 -0800 + Commit: 9a1de4b, github.com/apache/spark/pull/4561 + + [SPARK-5645] Added local read bytes/time to task metrics + Kay Ousterhout + 2015-02-12 14:35:44 -0800 + Commit: 74f34bb, github.com/apache/spark/pull/4510 + + [SQL] Improve error messages + Michael Armbrust , wangfei + 2015-02-12 13:11:28 -0800 + Commit: e3a975d, github.com/apache/spark/pull/4558 + + [SQL][DOCS] Update sql documentation + Antonio Navarro Perez + 2015-02-12 12:46:17 -0800 + Commit: cbd659e, github.com/apache/spark/pull/4560 + + [SPARK-5757][MLLIB] replace SQL JSON usage in model import/export by json4s + Xiangrui Meng + 2015-02-12 10:48:13 -0800 + Commit: e26c149, github.com/apache/spark/pull/4555 + + [SPARK-5655] Don't chmod700 application files if running in YARN + Andrew Rowson + 2015-02-12 18:41:39 +0000 + Commit: e23c8f5, github.com/apache/spark/pull/4509 + + [SQL] Make dataframe more tolerant of being serialized + Michael Armbrust + 2015-02-11 19:05:49 -0800 + Commit: 3c1b9bf, github.com/apache/spark/pull/4545 + + [SQL] Two DataFrame fixes. + Reynold Xin + 2015-02-11 18:32:48 -0800 + Commit: bcb1382, github.com/apache/spark/pull/4543 + + [SPARK-3688][SQL] More inline comments for LogicalPlan. + Reynold Xin + 2015-02-11 15:26:31 -0800 + Commit: 08ab3d2, github.com/apache/spark/pull/4539 + + [SPARK-3688][SQL]LogicalPlan can't resolve column correctlly + tianyi + 2015-02-11 12:50:17 -0800 + Commit: e136f47, github.com/apache/spark/pull/4524 + + [SPARK-5454] More robust handling of self joins + Michael Armbrust + 2015-02-11 12:31:56 -0800 + Commit: 1bb3631, github.com/apache/spark/pull/4520 + + Remove outdated remark about take(n). + Daniel Darabos + 2015-02-11 20:24:17 +0000 + Commit: 72adfc5, github.com/apache/spark/pull/4533 + + [SPARK-5677] [SPARK-5734] [SQL] [PySpark] Python DataFrame API remaining tasks + Davies Liu + 2015-02-11 12:13:16 -0800 + Commit: d66aae2, github.com/apache/spark/pull/4528 + + [SPARK-5733] Error Link in Pagination of HistroyPage when showing Incomplete Applications + guliangliang + 2015-02-11 15:55:49 +0000 + Commit: 864dccd, github.com/apache/spark/pull/4523 + + SPARK-5727 [BUILD] Deprecate Debian packaging + Sean Owen + 2015-02-11 08:30:16 +0000 + Commit: 057ec4f, github.com/apache/spark/pull/4516 + + SPARK-5728 [STREAMING] MQTTStreamSuite leaves behind ActiveMQ database files + Sean Owen + 2015-02-11 08:13:51 +0000 + Commit: 476b6d7, github.com/apache/spark/pull/4517 + + [SPARK-4964] [Streaming] refactor createRDD to take leaders via map instead of array + cody koeninger + 2015-02-11 00:13:27 -0800 + Commit: 811d179, github.com/apache/spark/pull/4511 + + Preparing development version 1.3.1-SNAPSHOT + Patrick Wendell + 2015-02-11 07:47:03 +0000 + Commit: e57c81b + + Preparing Spark release v1.3.0-snapshot1 + Patrick Wendell + 2015-02-11 07:47:02 +0000 + Commit: d97bfc6 + + Revert "Preparing Spark release v1.3.0-snapshot1" + Patrick Wendell + 2015-02-10 23:46:04 -0800 + Commit: 6a91d59 + + Revert "Preparing development version 1.3.1-SNAPSHOT" + Patrick Wendell + 2015-02-10 23:46:02 -0800 + Commit: 3a50383 + + HOTFIX: Adding Junit to Hive tests for Maven build + Patrick Wendell + 2015-02-10 23:39:21 -0800 + Commit: 0386fc4 + + Preparing development version 1.3.1-SNAPSHOT + Patrick Wendell + 2015-02-11 06:45:03 +0000 + Commit: ba12b79 + + Preparing Spark release v1.3.0-snapshot1 + Patrick Wendell + 2015-02-11 06:45:03 +0000 + Commit: 53068f5 + + HOTFIX: Java 6 compilation error in Spark SQL + Patrick Wendell + 2015-02-10 22:43:32 -0800 + Commit: 15180bc + + Revert "Preparing Spark release v1.3.0-snapshot1" + Patrick Wendell + 2015-02-10 22:44:10 -0800 + Commit: 536dae9 + + Revert "Preparing development version 1.3.1-SNAPSHOT" + Patrick Wendell + 2015-02-10 22:44:07 -0800 + Commit: 01b562e + + Preparing development version 1.3.1-SNAPSHOT + Patrick Wendell + 2015-02-11 06:15:29 +0000 + Commit: db80d0f + + Preparing Spark release v1.3.0-snapshot1 + Patrick Wendell + 2015-02-11 06:15:29 +0000 + Commit: c2e4001 + + Updating versions for Spark 1.3 + Patrick Wendell + 2015-02-10 21:54:55 -0800 + Commit: 2f52489 + + [SPARK-5714][Mllib] Refactor initial step of LDA to remove redundant operations + Liang-Chi Hsieh + 2015-02-10 21:51:15 -0800 + Commit: ba3aa8f, github.com/apache/spark/pull/4501 + + [SPARK-5702][SQL] Allow short names for built-in data sources. + Reynold Xin + 2015-02-10 20:40:21 -0800 + Commit: 63af90c, github.com/apache/spark/pull/4489 + + [SPARK-5729] Potential NPE in standalone REST API + Andrew Or + 2015-02-10 20:19:14 -0800 + Commit: 1bc75b0, github.com/apache/spark/pull/4518 + + [SPARK-4879] Use driver to coordinate Hadoop output committing for speculative tasks + mcheah , Josh Rosen + 2015-02-10 20:12:18 -0800 + Commit: 79cd59c, github.com/apache/spark/pull/4155. + + [SQL][DataFrame] Fix column computability bug. + Reynold Xin + 2015-02-10 19:50:44 -0800 + Commit: e477e91, github.com/apache/spark/pull/4519 + + [SPARK-5709] [SQL] Add EXPLAIN support in DataFrame API for debugging purpose + Cheng Hao + 2015-02-10 19:40:51 -0800 + Commit: 7fa0d5f, github.com/apache/spark/pull/4496 + + [SPARK-5704] [SQL] [PySpark] createDataFrame from RDD with columns + Davies Liu + 2015-02-10 19:40:12 -0800 + Commit: 1056c5b, github.com/apache/spark/pull/4498 + + [SPARK-5683] [SQL] Avoid multiple json generator created + Cheng Hao + 2015-02-10 18:19:56 -0800 + Commit: fc0446f, github.com/apache/spark/pull/4468 + + [SQL] Add an exception for analysis errors. + Michael Armbrust + 2015-02-10 17:32:42 -0800 + Commit: 748cdc1, github.com/apache/spark/pull/4439 + + [SPARK-5658][SQL] Finalize DDL and write support APIs + Yin Huai + 2015-02-10 17:29:52 -0800 + Commit: a21090e, github.com/apache/spark/pull/4446 + + [SPARK-5493] [core] Add option to impersonate user. + Marcelo Vanzin + 2015-02-10 17:19:10 -0800 + Commit: 8e75b0e, github.com/apache/spark/pull/4405 + + [SQL] Make Options in the data source API CREATE TABLE statements optional. + Yin Huai + 2015-02-10 17:06:12 -0800 + Commit: 445dbc7, github.com/apache/spark/pull/4515 + + [SPARK-5725] [SQL] Fixes ParquetRelation2.equals + Cheng Lian + 2015-02-10 17:02:44 -0800 + Commit: f43bc3d, github.com/apache/spark/pull/4513 + + [SPARK-5343][GraphX]: ShortestPaths traverses backwards + Brennon York + 2015-02-10 14:57:00 -0800 + Commit: 5be8902, github.com/apache/spark/pull/4478 + + [SPARK-5021] [MLlib] Gaussian Mixture now supports Sparse Input + MechCoder + 2015-02-10 14:05:55 -0800 + Commit: bba0953, github.com/apache/spark/pull/4459 + + [HOTFIX][SPARK-4136] Fix compilation and tests + Andrew Or + 2015-02-10 11:18:01 -0800 + Commit: 4e3aa68 + + [SPARK-5686][SQL] Add show current roles command in HiveQl + OopsOutOfMemory + 2015-02-10 13:20:15 -0800 + Commit: 8b7587a, github.com/apache/spark/pull/4471 + + [SQL] Add toString to DataFrame/Column + Michael Armbrust + 2015-02-10 13:14:01 -0800 + Commit: ef739d9, github.com/apache/spark/pull/4436 + + SPARK-5613: Catch the ApplicationNotFoundException exception to avoid thread from getting killed on yarn restart. + Kashish Jain + 2015-02-06 13:47:23 -0800 + Commit: c294216, github.com/apache/spark/pull/4392 + + [SPARK-5592][SQL] java.net.URISyntaxException when insert data to a partitioned table + wangfei , Fei Wang + 2015-02-10 11:54:30 -0800 + Commit: dbfce30, github.com/apache/spark/pull/4368 + + SPARK-4136. Under dynamic allocation, cancel outstanding executor requests when no longer needed + Sandy Ryza + 2015-02-10 11:07:25 -0800 + Commit: e53da21, github.com/apache/spark/pull/4168 + + [SPARK-5716] [SQL] Support TOK_CHARSETLITERAL in HiveQl + Daoyuan Wang + 2015-02-10 11:08:21 -0800 + Commit: e508237, github.com/apache/spark/pull/4502 + + [Spark-5717] [MLlib] add stop and reorganize import + JqueryFan , Yuhao Yang + 2015-02-10 17:37:32 +0000 + Commit: b32f553, github.com/apache/spark/pull/4503 + + [SPARK-5700] [SQL] [Build] Bumps jets3t to 0.9.3 for hadoop-2.3 and hadoop-2.4 profiles + Cheng Lian + 2015-02-10 02:28:47 -0800 + Commit: d6f31e0, github.com/apache/spark/pull/4499 + + SPARK-5239 [CORE] JdbcRDD throws "java.lang.AbstractMethodError: oracle.jdbc.driver.xxxxxx.isClosed()Z" + Sean Owen + 2015-02-10 09:19:01 +0000 + Commit: 4cfc025, github.com/apache/spark/pull/4470 + + [SPARK-4964][Streaming][Kafka] More updates to Exactly-once Kafka stream + Tathagata Das + 2015-02-09 22:45:48 -0800 + Commit: 281614d, github.com/apache/spark/pull/4384 + + [SPARK-5597][MLLIB] save/load for decision trees and emsembles + Joseph K. Bradley , Xiangrui Meng + 2015-02-09 22:09:07 -0800 + Commit: 01905c4, github.com/apache/spark/pull/4444. + + [SQL] Remove the duplicated code + Cheng Hao + 2015-02-09 21:33:34 -0800 + Commit: 663d34e, github.com/apache/spark/pull/4494 + + [SPARK-5701] Only set ShuffleReadMetrics when task has shuffle deps + Kay Ousterhout + 2015-02-09 21:22:09 -0800 + Commit: 6ddbca4, github.com/apache/spark/pull/4488 + + [SPARK-5703] AllJobsPage throws empty.max exception + Andrew Or + 2015-02-09 21:18:48 -0800 + Commit: 8326255, github.com/apache/spark/pull/4490 + + [SPARK-2996] Implement userClassPathFirst for driver, yarn. + Marcelo Vanzin + 2015-02-09 21:17:06 -0800 + Commit: 6a1e0f9, github.com/apache/spark/pull/3233 + + SPARK-4900 [MLLIB] MLlib SingularValueDecomposition ARPACK IllegalStateException + Sean Owen + 2015-02-09 21:13:58 -0800 + Commit: ebf1df0, github.com/apache/spark/pull/4485 + + Add a config option to print DAG. + KaiXinXiaoLei + 2015-02-09 20:58:58 -0800 + Commit: dad05e0, github.com/apache/spark/pull/4257 + + [SPARK-5469] restructure pyspark.sql into multiple files + Davies Liu + 2015-02-09 20:49:22 -0800 + Commit: f0562b4, github.com/apache/spark/pull/4479 + + [SPARK-5698] Do not let user request negative # of executors + Andrew Or + 2015-02-09 17:33:29 -0800 + Commit: 62b1e1f, github.com/apache/spark/pull/4483 + + [SPARK-5699] [SQL] [Tests] Runs hive-thriftserver tests whenever SQL code is modified + Cheng Lian + 2015-02-09 16:52:05 -0800 + Commit: 71f0f51, github.com/apache/spark/pull/4486 + + [SPARK-5648][SQL] support "alter ... unset tblproperties("key")" + DoingDone9 <799203320@qq.com> + 2015-02-09 16:40:26 -0800 + Commit: e2bf59a, github.com/apache/spark/pull/4424 + + [SPARK-2096][SQL] support dot notation on array of struct + Wenchen Fan + 2015-02-09 16:39:34 -0800 + Commit: 15f557f, github.com/apache/spark/pull/2405 + + [SPARK-5614][SQL] Predicate pushdown through Generate. + Lu Yan + 2015-02-09 16:25:38 -0800 + Commit: ce2c89c, github.com/apache/spark/pull/4394 + + [SPARK-5696] [SQL] [HOTFIX] Asks HiveThriftServer2 to re-initialize log4j using Hive configurations + Cheng Lian + 2015-02-09 16:23:12 -0800 + Commit: 379233c, github.com/apache/spark/pull/4484 + + [SQL] Code cleanup. + Yin Huai + 2015-02-09 16:20:42 -0800 + Commit: e241601, github.com/apache/spark/pull/4482 + + [SQL] Add some missing DataFrame functions. + Michael Armbrust + 2015-02-09 16:02:56 -0800 + Commit: a70dca0, github.com/apache/spark/pull/4437 + + [SPARK-5675][SQL] XyzType companion object should subclass XyzType + Reynold Xin + 2015-02-09 14:51:46 -0800 + Commit: 1e2fab2, github.com/apache/spark/pull/4463 + + [SPARK-4905][STREAMING] FlumeStreamSuite fix. + Hari Shreedharan + 2015-02-09 14:17:14 -0800 + Commit: 18c5a99, github.com/apache/spark/pull/4371 + + [SPARK-5691] Fixing wrong data structure lookup for dupe app registratio... + mcheah + 2015-02-09 13:20:14 -0800 + Commit: 6a0144c, github.com/apache/spark/pull/4477 + + [SPARK-5678] Convert DataFrame to pandas.DataFrame and Series + Davies Liu + 2015-02-09 11:42:52 -0800 + Commit: 43972b5, github.com/apache/spark/pull/4476 + + [SPARK-5664][BUILD] Restore stty settings when exiting from SBT's spark-shell + Liang-Chi Hsieh + 2015-02-09 11:45:12 -0800 + Commit: fa67877, github.com/apache/spark/pull/4451 + + SPARK-4267 [YARN] Failing to launch jobs on Spark on YARN with Hadoop 2.5.0 or later + Sean Owen + 2015-02-09 10:33:57 -0800 + Commit: c88d4ab, github.com/apache/spark/pull/4452 + + [SPARK-5473] [EC2] Expose SSH failures after status checks pass + Nicholas Chammas + 2015-02-09 09:44:53 +0000 + Commit: f2aa7b7, github.com/apache/spark/pull/4262 + + [SPARK-5539][MLLIB] LDA guide + Xiangrui Meng , Joseph K. Bradley + 2015-02-08 23:40:36 -0800 + Commit: 5782ee2, github.com/apache/spark/pull/4465 + + [SPARK-5472][SQL] Fix Scala code style + Hung Lin + 2015-02-08 22:36:42 -0800 + Commit: 955f286, github.com/apache/spark/pull/4464 + + SPARK-4405 [MLLIB] Matrices.* construction methods should check for rows x cols overflow + Sean Owen + 2015-02-08 21:08:50 -0800 + Commit: fa8ea48, github.com/apache/spark/pull/4461 + + [SPARK-5660][MLLIB] Make Matrix apply public + Joseph K. Bradley , Xiangrui Meng + 2015-02-08 21:07:36 -0800 + Commit: df9b105, github.com/apache/spark/pull/4447 + + [SPARK-5643][SQL] Add a show method to print the content of a DataFrame in tabular format. + Reynold Xin + 2015-02-08 18:56:51 -0800 + Commit: e1996aa, github.com/apache/spark/pull/4416 + + SPARK-5665 [DOCS] Update netlib-java documentation + Sam Halliday , Sam Halliday + 2015-02-08 16:34:26 -0800 + Commit: c515634, github.com/apache/spark/pull/4448 + + [SPARK-5598][MLLIB] model save/load for ALS + Xiangrui Meng + 2015-02-08 16:26:20 -0800 + Commit: 9e4d58f, github.com/apache/spark/pull/4422 + + [SQL] Set sessionState in QueryExecution. + Yin Huai + 2015-02-08 14:55:07 -0800 + Commit: 42c56b6, github.com/apache/spark/pull/4445 + + [SPARK-3039] [BUILD] Spark assembly for new hadoop API (hadoop 2) contai... + medale + 2015-02-08 10:35:29 +0000 + Commit: bc55e20, github.com/apache/spark/pull/4315 + + [SPARK-5672][Web UI] Don't return `ERROR 500` when have missing args + Kirill A. Korinskiy + 2015-02-08 10:31:46 +0000 + Commit: 96010fa, github.com/apache/spark/pull/4239 + + [SPARK-5671] Upgrade jets3t to 0.9.2 in hadoop-2.3 and 2.4 profiles + Josh Rosen + 2015-02-07 17:19:08 -0800 + Commit: 0f9d765, github.com/apache/spark/pull/4454 + + [SPARK-5108][BUILD] Jackson dependency management for Hadoop-2.6.0 support + Zhan Zhang + 2015-02-07 19:41:30 +0000 + Commit: 51fbca4, github.com/apache/spark/pull/3938 + + [BUILD] Add the ability to launch spark-shell from SBT. + Michael Armbrust + 2015-02-07 00:14:38 -0800 + Commit: 6bda169, github.com/apache/spark/pull/4438 + + [SPARK-5388] Provide a stable application submission gateway for standalone cluster mode + Andrew Or + 2015-02-06 15:57:06 -0800 + Commit: 6ec0cdc, github.com/apache/spark/pull/4216 + + SPARK-5403: Ignore UserKnownHostsFile in SSH calls + Grzegorz Dubicki + 2015-02-06 15:43:58 -0800 + Commit: 3d99741, github.com/apache/spark/pull/4196 + + [SPARK-5601][MLLIB] make streaming linear algorithms Java-friendly + Xiangrui Meng + 2015-02-06 15:42:59 -0800 + Commit: 11b28b9, github.com/apache/spark/pull/4432 + + [SQL] [Minor] HiveParquetSuite was disabled by mistake, re-enable them + Cheng Lian + 2015-02-06 15:23:42 -0800 + Commit: 4005802, github.com/apache/spark/pull/4440 + + [SQL] Use TestSQLContext in Java tests + Michael Armbrust + 2015-02-06 15:11:02 -0800 + Commit: c950058, github.com/apache/spark/pull/4441 + + [SPARK-4994][network]Cleanup removed executors' ShuffleInfo in yarn shuffle service + lianhuiwang + 2015-02-06 14:47:52 -0800 + Commit: af6ddf8, github.com/apache/spark/pull/3828 + + [SPARK-5444][Network]Add a retry to deal with the conflict port in netty server. + huangzhaowei + 2015-02-06 14:35:29 -0800 + Commit: caca15a, github.com/apache/spark/pull/4240 + + [SPARK-4874] [CORE] Collect record count metrics + Kostas Sakellis + 2015-02-06 14:31:20 -0800 + Commit: 9fa29a6, github.com/apache/spark/pull/4067 + + [HOTFIX] Fix the maven build after adding sqlContext to spark-shell + Michael Armbrust + 2015-02-06 14:27:06 -0800 + Commit: 11dbf71, github.com/apache/spark/pull/4443 + + [SPARK-5600] [core] Clean up FsHistoryProvider test, fix app sort order. + Marcelo Vanzin + 2015-02-06 14:23:09 -0800 + Commit: 09feecc, github.com/apache/spark/pull/4370 + + SPARK-5633 pyspark saveAsTextFile support for compression codec + Vladimir Vladimirov + 2015-02-06 13:55:02 -0800 + Commit: 1d32341, github.com/apache/spark/pull/4403 + + [HOTFIX][MLLIB] fix a compilation error with java 6 + Xiangrui Meng + 2015-02-06 13:52:35 -0800 + Commit: 87e0f0d, github.com/apache/spark/pull/4442 + + [SPARK-4983] Insert waiting time before tagging EC2 instances + GenTang , Gen TANG + 2015-02-06 13:27:34 -0800 + Commit: 2872d83, github.com/apache/spark/pull/3986 + + [SPARK-5586][Spark Shell][SQL] Make `sqlContext` available in spark shell + OopsOutOfMemory + 2015-02-06 13:20:10 -0800 + Commit: 2ef9853, github.com/apache/spark/pull/4387 + + [SPARK-5278][SQL] Introduce UnresolvedGetField and complete the check of ambiguous reference to fields + Wenchen Fan + 2015-02-06 13:08:09 -0800 + Commit: 1b148ad, github.com/apache/spark/pull/4068 + + [SQL][Minor] Remove cache keyword in SqlParser + wangfei + 2015-02-06 12:42:23 -0800 + Commit: d822606, github.com/apache/spark/pull/4393 + + [SQL][HiveConsole][DOC] HiveConsole `correct hiveconsole imports` + OopsOutOfMemory + 2015-02-06 12:41:28 -0800 + Commit: 2abaa6e, github.com/apache/spark/pull/4389 + + [SPARK-5595][SPARK-5603][SQL] Add a rule to do PreInsert type casting and field renaming and invalidating in memory cache after INSERT + Yin Huai + 2015-02-06 12:38:07 -0800 + Commit: 3c34d62, github.com/apache/spark/pull/4373 + + [SPARK-5324][SQL] Results of describe can't be queried + OopsOutOfMemory , Sheng, Li + 2015-02-06 12:33:20 -0800 + Commit: 0fc35da, github.com/apache/spark/pull/4249 + + [SPARK-5619][SQL] Support 'show roles' in HiveContext + q00251598 + 2015-02-06 12:29:26 -0800 + Commit: cc66a3c, github.com/apache/spark/pull/4397 + + [SPARK-5640] Synchronize ScalaReflection where necessary + Tobias Schlatter + 2015-02-06 12:15:02 -0800 + Commit: 779e28b, github.com/apache/spark/pull/4431 + + [SPARK-5650][SQL] Support optional 'FROM' clause + Liang-Chi Hsieh + 2015-02-06 12:13:44 -0800 + Commit: 921121d, github.com/apache/spark/pull/4426 + + [SPARK-5628] Add version option to spark-ec2 + Nicholas Chammas + 2015-02-06 12:08:22 -0800 + Commit: ab0ffde, github.com/apache/spark/pull/4414 + + [SPARK-2945][YARN][Doc]add doc for spark.executor.instances + WangTaoTheTonic + 2015-02-06 11:57:02 -0800 + Commit: 540f474, github.com/apache/spark/pull/4350 + + [SPARK-4361][Doc] Add more docs for Hadoop Configuration + zsxwing + 2015-02-06 11:50:20 -0800 + Commit: 528dd34, github.com/apache/spark/pull/3225 + + [HOTFIX] Fix test build break in ExecutorAllocationManagerSuite. + Josh Rosen + 2015-02-06 11:47:32 -0800 + Commit: 9e828f4 + + [SPARK-5652][Mllib] Use broadcasted weights in LogisticRegressionModel + Liang-Chi Hsieh + 2015-02-06 11:22:11 -0800 + Commit: 6fda4c1, github.com/apache/spark/pull/4429 + + [SPARK-5555] Enable UISeleniumSuite tests + Josh Rosen + 2015-02-06 11:14:58 -0800 + Commit: 93fee7b, github.com/apache/spark/pull/4334 + + SPARK-2450 Adds executor log links to Web UI + Kostas Sakellis , Josh Rosen + 2015-02-06 11:13:00 -0800 + Commit: e74dd04, github.com/apache/spark/pull/3486 + + [SPARK-5618][Spark Core][Minor] Optimise utility code. + Makoto Fukuhara + 2015-02-06 11:11:38 -0800 + Commit: 3feb798, github.com/apache/spark/pull/4396 + + [SPARK-5593][Core]Replace BlockManagerListener with ExecutorListener in ExecutorAllocationListener + lianhuiwang + 2015-02-06 11:09:37 -0800 + Commit: 9387dc1, github.com/apache/spark/pull/4369 + + [SPARK-4877] Allow user first classes to extend classes in the parent. + Stephen Haberman + 2015-02-06 11:03:56 -0800 + Commit: 52386cf, github.com/apache/spark/pull/3725 + + [SPARK-5396] Syntax error in spark scripts on windows. + Masayoshi TSUZUKI + 2015-02-06 10:58:26 -0800 + Commit: 2dc94cd, github.com/apache/spark/pull/4428 + + [SPARK-5636] Ramp up faster in dynamic allocation + Andrew Or + 2015-02-06 10:54:23 -0800 + Commit: 0a90305, github.com/apache/spark/pull/4409 + + SPARK-4337. [YARN] Add ability to cancel pending requests + Sandy Ryza + 2015-02-06 10:53:16 -0800 + Commit: 1568391, github.com/apache/spark/pull/4141 + + [SPARK-5416] init Executor.threadPool before ExecutorSource + Ryan Williams + 2015-02-06 12:22:25 +0000 + Commit: f9bc4ef, github.com/apache/spark/pull/4212 + + [Build] Set all Debian package permissions to 755 + Nicholas Chammas + 2015-02-06 11:38:39 +0000 + Commit: 3638216, github.com/apache/spark/pull/4277 + + Update ec2-scripts.md + Miguel Peralvo + 2015-02-06 11:04:48 +0000 + Commit: f6613fc, github.com/apache/spark/pull/4300 + + [SPARK-5470][Core]use defaultClassLoader to load classes in KryoSerializer + lianhuiwang + 2015-02-06 11:00:35 +0000 + Commit: 8007a4f, github.com/apache/spark/pull/4258 + + [SPARK-5653][YARN] In ApplicationMaster rename isDriver to isClusterMode + lianhuiwang + 2015-02-06 10:48:31 -0800 + Commit: 4ff8855, github.com/apache/spark/pull/4430 + + [SPARK-5582] [history] Ignore empty log directories. + Marcelo Vanzin + 2015-02-06 10:07:20 +0000 + Commit: faccdcb, github.com/apache/spark/pull/4352 + + [SPARK-5157][YARN] Configure more JVM options properly when we use ConcMarkSweepGC for AM. + Kousuke Saruta + 2015-02-06 09:39:12 +0000 + Commit: 25d8044, github.com/apache/spark/pull/3956 + + [Minor] Remove permission for execution from spark-shell.cmd + Kousuke Saruta + 2015-02-06 09:33:36 +0000 + Commit: 7c54681, github.com/apache/spark/pull/3983 + + [SPARK-5380][GraphX] Solve an ArrayIndexOutOfBoundsException when build graph with a file format error + Leolh + 2015-02-06 09:01:53 +0000 + Commit: ffdb2e9, github.com/apache/spark/pull/4176 + + [SPARK-5013] [MLlib] Added documentation and sample data file for GaussianMixture + Travis Galoppo + 2015-02-06 10:26:51 -0800 + Commit: f408db6, github.com/apache/spark/pull/4401 + + [SPARK-4789] [SPARK-4942] [SPARK-5031] [mllib] Standardize ML Prediction APIs + Joseph K. Bradley + 2015-02-05 23:43:47 -0800 + Commit: 45b95e7, github.com/apache/spark/pull/3637 + + [SPARK-5604][MLLIB] remove checkpointDir from trees + Xiangrui Meng + 2015-02-05 23:32:09 -0800 + Commit: c35a11e, github.com/apache/spark/pull/4407 + + [SPARK-5639][SQL] Support DataFrame.renameColumn. + Reynold Xin + 2015-02-05 23:02:40 -0800 + Commit: 0639d3e, github.com/apache/spark/pull/4410 + + Revert "SPARK-5607: Update to Kryo 2.24.0 to avoid including objenesis 1.2." + Patrick Wendell + 2015-02-05 18:37:55 -0800 + Commit: 6d31531 + + SPARK-5557: Explicitly include servlet API in dependencies. + Patrick Wendell + 2015-02-05 18:14:54 -0800 + Commit: 34131fd, github.com/apache/spark/pull/4411 + + [HOTFIX] [SQL] Disables Metastore Parquet table conversion for "SQLQuerySuite.CTAS with serde" + Cheng Lian + 2015-02-05 18:09:18 -0800 + Commit: ce6d8bb, github.com/apache/spark/pull/4413 + + [SPARK-5638][SQL] Add a config flag to disable eager analysis of DataFrames + Reynold Xin + 2015-02-05 18:07:10 -0800 + Commit: 4fd67e4, github.com/apache/spark/pull/4408 + + [SPARK-5620][DOC] group methods in generated unidoc + Xiangrui Meng + 2015-02-05 16:26:51 -0800 + Commit: e2be79d, github.com/apache/spark/pull/4404 + + [SPARK-5182] [SPARK-5528] [SPARK-5509] [SPARK-3575] [SQL] Parquet data source improvements + Cheng Lian + 2015-02-05 15:29:56 -0800 + Commit: 50c48eb, github.com/apache/spark/pull/4308 + + [SPARK-5604[MLLIB] remove checkpointDir from LDA + Xiangrui Meng + 2015-02-05 15:07:33 -0800 + Commit: 59798cb, github.com/apache/spark/pull/4390 + + [SPARK-5460][MLlib] Wrapped `Try` around `deleteAllCheckpoints` - RandomForest. + x1- + 2015-02-05 15:02:04 -0800 + Commit: 44768f5, github.com/apache/spark/pull/4347 + + [SPARK-5135][SQL] Add support for describe table to DDL in SQLContext + OopsOutOfMemory + 2015-02-05 13:07:48 -0800 + Commit: 55cebcf, github.com/apache/spark/pull/4227 + + [SPARK-5617][SQL] fix test failure of SQLQuerySuite + wangfei + 2015-02-05 12:44:12 -0800 + Commit: 785a2e3, github.com/apache/spark/pull/4395 + + [Branch-1.3] [DOC] doc fix for date + Daoyuan Wang + 2015-02-05 12:42:27 -0800 + Commit: 17ef7f9, github.com/apache/spark/pull/4400 + + [SPARK-5474][Build]curl should support URL redirection in build/mvn + GuoQiang Li + 2015-02-05 12:03:13 -0800 + Commit: d1066e9, github.com/apache/spark/pull/4263 + + [HOTFIX] MLlib build break. + Reynold Xin + 2015-02-05 00:42:50 -0800 + Commit: c83d118 + + SPARK-5548: Fixed a race condition in AkkaUtilsSuite + Jacek Lewandowski + 2015-02-05 12:00:04 -0800 + Commit: fba2dc6, github.com/apache/spark/pull/4343 + + [SPARK-5608] Improve SEO of Spark documentation pages + Matei Zaharia + 2015-02-05 11:12:50 -0800 + Commit: de112a2, github.com/apache/spark/pull/4381 + + SPARK-4687. Add a recursive option to the addFile API + Sandy Ryza + 2015-02-05 10:15:55 -0800 + Commit: c22ccc0, github.com/apache/spark/pull/3670 + + [MLlib] Minor: UDF style update. + Reynold Xin + 2015-02-04 23:57:53 -0800 + Commit: 4074674, github.com/apache/spark/pull/4388 + + [SPARK-5612][SQL] Move DataFrame implicit functions into SQLContext.implicits. + Reynold Xin + 2015-02-04 23:44:34 -0800 + Commit: 0040b61, github.com/apache/spark/pull/4386 + + [SPARK-5606][SQL] Support plus sign in HiveContext + q00251598 + 2015-02-04 23:16:01 -0800 + Commit: bf43781, github.com/apache/spark/pull/4378 + + [SPARK-5599] Check MLlib public APIs for 1.3 + Xiangrui Meng + 2015-02-04 23:03:47 -0800 + Commit: abc184e, github.com/apache/spark/pull/4377 + + [SPARK-5596] [mllib] ML model import/export for GLMs, NaiveBayes + Joseph K. Bradley + 2015-02-04 22:46:48 -0800 + Commit: 885bcbb, github.com/apache/spark/pull/4233 + + SPARK-5607: Update to Kryo 2.24.0 to avoid including objenesis 1.2. + Patrick Wendell + 2015-02-04 22:39:44 -0800 + Commit: 59fb5c7, github.com/apache/spark/pull/4383 + + [SPARK-5602][SQL] Better support for creating DataFrame from local data collection + Reynold Xin + 2015-02-04 19:53:57 -0800 + Commit: b8f9c00, github.com/apache/spark/pull/4372 + + [SPARK-5538][SQL] Fix flaky CachedTableSuite + Reynold Xin + 2015-02-04 19:52:41 -0800 + Commit: 1901b19, github.com/apache/spark/pull/4379 + + [SQL][DataFrame] Minor cleanup. + Reynold Xin + 2015-02-04 19:51:48 -0800 + Commit: f05bfa6, github.com/apache/spark/pull/4374 + + [SPARK-4520] [SQL] This pr fixes the ArrayIndexOutOfBoundsException as r... + Sadhan Sood + 2015-02-04 19:18:06 -0800 + Commit: aa6f4ca, github.com/apache/spark/pull/4148 + + [SPARK-5605][SQL][DF] Allow using String to specify colum name in DSL aggregate functions + Reynold Xin + 2015-02-04 18:35:51 -0800 + Commit: 478ee3f, github.com/apache/spark/pull/4376 + + [SPARK-5411] Allow SparkListeners to be specified in SparkConf and loaded when creating SparkContext + Josh Rosen + 2015-02-04 17:18:03 -0800 + Commit: 47e4d57, github.com/apache/spark/pull/4111 + + [SPARK-5577] Python udf for DataFrame + Davies Liu + 2015-02-04 15:55:09 -0800 + Commit: dc9ead9, github.com/apache/spark/pull/4351 + + [SPARK-5118][SQL] Fix: create table test stored as parquet as select .. + guowei2 + 2015-02-04 15:26:10 -0800 + Commit: 06da868, github.com/apache/spark/pull/3921 + + [SQL] Use HiveContext's sessionState in HiveMetastoreCatalog.hiveDefaultTableFilePath + Yin Huai + 2015-02-04 15:22:40 -0800 + Commit: cb4c3e5, github.com/apache/spark/pull/4355 + + [SQL] Correct the default size of TimestampType and expose NumericType + Yin Huai + 2015-02-04 15:14:49 -0800 + Commit: 513bb2c, github.com/apache/spark/pull/4314 + + [SQL][Hiveconsole] Bring hive console code up to date and update README.md + OopsOutOfMemory , Sheng, Li + 2015-02-04 15:13:54 -0800 + Commit: 2cdcfe3, github.com/apache/spark/pull/4330 + + [SPARK-5367][SQL] Support star expression in udfs + wangfei , scwf + 2015-02-04 15:12:07 -0800 + Commit: 8b803f6, github.com/apache/spark/pull/4353 + + [SPARK-5426][SQL] Add SparkSQL Java API helper methods. + kul + 2015-02-04 15:08:37 -0800 + Commit: 38ab92e, github.com/apache/spark/pull/4243 + + [SPARK-5587][SQL] Support change database owner + wangfei + 2015-02-04 14:35:12 -0800 + Commit: 7920791, github.com/apache/spark/pull/4357 + + [SPARK-5591][SQL] Fix NoSuchObjectException for CTAS + wangfei + 2015-02-04 14:33:07 -0800 + Commit: c79dd1e, github.com/apache/spark/pull/4365 + + [SPARK-4939] move to next locality when no pending tasks + Davies Liu + 2015-02-04 14:22:07 -0800 + Commit: f9bb3cb, github.com/apache/spark/pull/3779 + + [SPARK-4707][STREAMING] Reliable Kafka Receiver can lose data if the blo... + Hari Shreedharan + 2015-02-04 14:20:44 -0800 + Commit: 14c9f32, github.com/apache/spark/pull/3655 + + [SPARK-4964] [Streaming] Exactly-once semantics for Kafka + cody koeninger + 2015-02-04 12:06:34 -0800 + Commit: a119cae, github.com/apache/spark/pull/3798 + + [SPARK-5588] [SQL] support select/filter by SQL expression + Davies Liu + 2015-02-04 11:34:46 -0800 + Commit: 950a0d3, github.com/apache/spark/pull/4359 + + [SPARK-5585] Flaky test in MLlib python + Davies Liu + 2015-02-04 08:54:20 -0800 + Commit: 84c6273, github.com/apache/spark/pull/4358 + + [SPARK-5574] use given name prefix in dir + Imran Rashid + 2015-02-04 01:02:20 -0800 + Commit: 5d9278a, github.com/apache/spark/pull/4344 + + [Minor] Fix incorrect warning log + Liang-Chi Hsieh + 2015-02-04 00:52:41 -0800 + Commit: 316a4bb, github.com/apache/spark/pull/4360 + + [SPARK-5379][Streaming] Add awaitTerminationOrTimeout + zsxwing + 2015-02-04 00:40:28 -0800 + Commit: 4d3dbfd, github.com/apache/spark/pull/4171 + + [SPARK-5341] Use maven coordinates as dependencies in spark-shell and spark-submit + Burak Yavuz + 2015-02-03 22:39:17 -0800 + Commit: 3b7acd2, github.com/apache/spark/pull/4215 + + [SPARK-4939] revive offers periodically in LocalBackend + Davies Liu + 2015-02-03 22:30:23 -0800 + Commit: e196da8, github.com/apache/spark/pull/4147 + + [SPARK-4969][STREAMING][PYTHON] Add binaryRecords to streaming + freeman + 2015-02-03 22:24:30 -0800 + Commit: 9a33f89, github.com/apache/spark/pull/3803 + + [SPARK-5579][SQL][DataFrame] Support for project/filter using SQL expressions + Reynold Xin + 2015-02-03 22:15:35 -0800 + Commit: cb7f783, github.com/apache/spark/pull/4348 + + [FIX][MLLIB] fix seed handling in Python GMM + Xiangrui Meng + 2015-02-03 20:39:11 -0800 + Commit: 679228b, github.com/apache/spark/pull/4349 + + [SPARK-4795][Core] Redesign the "primitive type => Writable" implicit APIs to make them be activated automatically + zsxwing + 2015-02-03 20:17:12 -0800 + Commit: 5c63e05, github.com/apache/spark/pull/3642 + + [SPARK-5578][SQL][DataFrame] Provide a convenient way for Scala users to use UDFs + Reynold Xin + 2015-02-03 20:07:46 -0800 + Commit: b22d5b5, github.com/apache/spark/pull/4345 + + [SPARK-5520][MLlib] Make FP-Growth implementation take generic item types (WIP) + Jacky Li , Jacky Li , Xiangrui Meng + 2015-02-03 17:02:42 -0800 + Commit: 298ef5b, github.com/apache/spark/pull/4340 + + [SPARK-5554] [SQL] [PySpark] add more tests for DataFrame Python API + Davies Liu + 2015-02-03 16:01:56 -0800 + Commit: 4640623, github.com/apache/spark/pull/4331 + + [STREAMING] SPARK-4986 Wait for receivers to deregister and receiver job to terminate + Jesper Lundgren + 2015-02-03 14:53:39 -0800 + Commit: 092d4ba, github.com/apache/spark/pull/4338 + + [SPARK-5153][Streaming][Test] Increased timeout to deal with flaky KafkaStreamSuite + Tathagata Das + 2015-02-03 13:46:02 -0800 + Commit: d644bd9, github.com/apache/spark/pull/4342 + + [SPARK-4508] [SQL] build native date type to conform behavior to Hive + Daoyuan Wang + 2015-02-03 12:21:45 -0800 + Commit: 6e244cf, github.com/apache/spark/pull/4325 + + [SPARK-5383][SQL] Support alias for udtfs + wangfei , scwf , Fei Wang + 2015-02-03 12:16:31 -0800 + Commit: 5dbeb21, github.com/apache/spark/pull/4186 + + [SPARK-5550] [SQL] Support the case insensitive for UDF + Cheng Hao + 2015-02-03 12:12:26 -0800 + Commit: 654c992, github.com/apache/spark/pull/4326 + + [SPARK-4987] [SQL] parquet timestamp type support + Daoyuan Wang + 2015-02-03 12:06:06 -0800 + Commit: 67d5220, github.com/apache/spark/pull/3820 + + [SQL] DataFrame API update + Reynold Xin + 2015-02-03 10:34:56 -0800 + Commit: 4204a12, github.com/apache/spark/pull/4332 + + Minor: Fix TaskContext deprecated annotations. + Reynold Xin + 2015-02-03 10:34:16 -0800 + Commit: f7948f3, github.com/apache/spark/pull/4333 + + [SPARK-5549] Define TaskContext interface in Scala. + Reynold Xin + 2015-02-03 00:46:04 -0800 + Commit: bebf4c4, github.com/apache/spark/pull/4324 + + [SPARK-5551][SQL] Create type alias for SchemaRDD for source backward compatibility + Reynold Xin + 2015-02-03 00:29:23 -0800 + Commit: 523a935, github.com/apache/spark/pull/4327 + + [SQL][DataFrame] Remove DataFrameApi, ExpressionApi, and GroupedDataFrameApi + Reynold Xin + 2015-02-03 00:29:04 -0800 + Commit: 37df330, github.com/apache/spark/pull/4328 + + [minor] update streaming linear algorithms + Xiangrui Meng + 2015-02-03 00:14:43 -0800 + Commit: 659329f, github.com/apache/spark/pull/4329 + + [SPARK-1405] [mllib] Latent Dirichlet Allocation (LDA) using EM + Xiangrui Meng + 2015-02-02 23:57:35 -0800 + Commit: 980764f, github.com/apache/spark/pull/2388 + + [SPARK-5536] replace old ALS implementation by the new one + Xiangrui Meng + 2015-02-02 23:49:09 -0800 + Commit: 0cc7b88, github.com/apache/spark/pull/4321 + + [SPARK-5414] Add SparkFirehoseListener class for consuming all SparkListener events + Josh Rosen + 2015-02-02 23:35:07 -0800 + Commit: b8ebebe, github.com/apache/spark/pull/4210 + + [SPARK-5501][SPARK-5420][SQL] Write support for the data source API + Yin Huai + 2015-02-02 23:30:44 -0800 + Commit: 13531dd, github.com/apache/spark/pull/4294 + + [SPARK-5012][MLLib][PySpark]Python API for Gaussian Mixture Model + FlytxtRnD + 2015-02-02 23:04:55 -0800 + Commit: 50a1a87, github.com/apache/spark/pull/4059 + + [SPARK-3778] newAPIHadoopRDD doesn't properly pass credentials for secure hdfs + Thomas Graves + 2015-02-02 22:45:55 -0800 + Commit: c31c36c, github.com/apache/spark/pull/4292 + + [SPARK-4979][MLLIB] Streaming logisitic regression + freeman + 2015-02-02 22:42:15 -0800 + Commit: eb0da6c, github.com/apache/spark/pull/4306 + + [SPARK-5219][Core] Add locks to avoid scheduling race conditions + zsxwing + 2015-02-02 21:42:18 -0800 + Commit: c306555, github.com/apache/spark/pull/4019 + + [Doc] Minor: Fixes several formatting issues + Cheng Lian + 2015-02-02 21:14:21 -0800 + Commit: 60f67e7, github.com/apache/spark/pull/4316 + + SPARK-3996: Add jetty servlet and continuations. + Patrick Wendell + 2015-02-02 21:01:36 -0800 + Commit: 7930d2b, github.com/apache/spark/pull/4323 + + SPARK-5542: Decouple publishing, packaging, and tagging in release script + Patrick Wendell , Patrick Wendell + 2015-02-02 21:00:30 -0800 + Commit: 0ef38f5, github.com/apache/spark/pull/4319 + + [SPARK-5543][WebUI] Remove unused import JsonUtil from from JsonProtocol + nemccarthy + 2015-02-02 20:03:13 -0800 + Commit: cb39f12, github.com/apache/spark/pull/4320 + + [SPARK-5472][SQL] A JDBC data source for Spark SQL. + Tor Myklebust + 2015-02-02 19:50:14 -0800 + Commit: 8f471a6, github.com/apache/spark/pull/4261 + + [SPARK-5512][Mllib] Run the PIC algorithm with initial vector suggected by the PIC paper + Liang-Chi Hsieh + 2015-02-02 19:34:25 -0800 + Commit: 1bcd465, github.com/apache/spark/pull/4301 + + [SPARK-5154] [PySpark] [Streaming] Kafka streaming support in Python + Davies Liu , Tathagata Das + 2015-02-02 19:16:27 -0800 + Commit: 0561c45, github.com/apache/spark/pull/3715 + + [SQL] Improve DataFrame API error reporting + Reynold Xin , Davies Liu + 2015-02-02 19:01:47 -0800 + Commit: 554403f, github.com/apache/spark/pull/4296 + + Revert "[SPARK-4508] [SQL] build native date type to conform behavior to Hive" + Patrick Wendell + 2015-02-02 17:52:17 -0800 + Commit: eccb9fb + + Spark 3883: SSL support for HttpServer and Akka + Jacek Lewandowski , Jacek Lewandowski + 2015-02-02 17:18:54 -0800 + Commit: cfea300, github.com/apache/spark/pull/3571 + + [SPARK-5540] hide ALS.solveLeastSquares + Xiangrui Meng + 2015-02-02 17:10:01 -0800 + Commit: ef65cf0, github.com/apache/spark/pull/4318 + + [SPARK-5534] [graphx] Graph getStorageLevel fix + Joseph K. Bradley + 2015-02-02 17:02:29 -0800 + Commit: f133dec, github.com/apache/spark/pull/4317 + + [SPARK-5514] DataFrame.collect should call executeCollect + Reynold Xin + 2015-02-02 16:55:36 -0800 + Commit: 8aa3cff, github.com/apache/spark/pull/4313 + + [SPARK-5195][sql]Update HiveMetastoreCatalog.scala(override the MetastoreRelation's sameresult method only compare databasename and table name) + seayi <405078363@qq.com>, Michael Armbrust + 2015-02-02 16:06:52 -0800 + Commit: dca6faa, github.com/apache/spark/pull/3898 + + [SPARK-2309][MLlib] Multinomial Logistic Regression + DB Tsai + 2015-02-02 15:59:15 -0800 + Commit: b1aa8fe, github.com/apache/spark/pull/3833 + + [SPARK-5513][MLLIB] Add nonnegative option to ml's ALS + Xiangrui Meng + 2015-02-02 15:55:44 -0800 + Commit: 46d50f1, github.com/apache/spark/pull/4302 + + [SPARK-4508] [SQL] build native date type to conform behavior to Hive + Daoyuan Wang + 2015-02-02 15:49:22 -0800 + Commit: 1646f89, github.com/apache/spark/pull/3732 + + SPARK-5500. Document that feeding hadoopFile into a shuffle operation wi... + Sandy Ryza + 2015-02-02 14:52:46 -0800 + Commit: 8309349, github.com/apache/spark/pull/4293 + + [SPARK-5461] [graphx] Add isCheckpointed, getCheckpointedFiles methods to Graph + Joseph K. Bradley + 2015-02-02 14:34:48 -0800 + Commit: 842d000, github.com/apache/spark/pull/4253 + + SPARK-5425: Use synchronised methods in system properties to create SparkConf + Jacek Lewandowski + 2015-02-02 14:07:19 -0800 + Commit: 5a55261, github.com/apache/spark/pull/4222 + + Disabling Utils.chmod700 for Windows + Martin Weindel , mweindel + 2015-02-02 13:46:18 -0800 + Commit: bff65b5, github.com/apache/spark/pull/4299 + + Make sure only owner can read / write to directories created for the job. + Josh Rosen + 2015-01-21 14:38:14 -0800 + Commit: 52f5754 + + [HOTFIX] Add jetty references to build for YARN module. + Patrick Wendell + 2015-02-02 14:00:14 -0800 + Commit: 2321dd1 + + [SPARK-4631][streaming][FIX] Wait for a receiver to start before publishing test data. + Iulian Dragos + 2015-02-02 14:00:33 -0800 + Commit: e908322, github.com/apache/spark/pull/4270 + + [SPARK-5212][SQL] Add support of schema-less, custom field delimiter and SerDe for HiveQL transform + Liang-Chi Hsieh + 2015-02-02 13:53:55 -0800 + Commit: 683e938, github.com/apache/spark/pull/4014 + + [SPARK-5530] Add executor container to executorIdToContainer + Xutingjun <1039320815@qq.com> + 2015-02-02 12:37:51 -0800 + Commit: 62a93a1, github.com/apache/spark/pull/4309 + + [Docs] Fix Building Spark link text + Nicholas Chammas + 2015-02-02 12:33:49 -0800 + Commit: 3f941b6, github.com/apache/spark/pull/4312 + + [SPARK-5173]support python application running on yarn cluster mode + lianhuiwang , Wang Lianhui + 2015-02-02 12:32:28 -0800 + Commit: f5e6375, github.com/apache/spark/pull/3976 + + SPARK-4585. Spark dynamic executor allocation should use minExecutors as... + Sandy Ryza + 2015-02-02 12:27:08 -0800 + Commit: b2047b5, github.com/apache/spark/pull/4051 + + [MLLIB] SPARK-5491 (ex SPARK-1473): Chi-square feature selection + Alexander Ulanov + 2015-02-02 12:13:05 -0800 + Commit: c081b21, github.com/apache/spark/pull/1484 + + SPARK-5492. Thread statistics can break with older Hadoop versions + Sandy Ryza + 2015-02-02 00:54:06 -0800 + Commit: 6f34131, github.com/apache/spark/pull/4305 + + [SPARK-5478][UI][Minor] Add missing right parentheses + jerryshao + 2015-02-01 23:56:13 -0800 + Commit: 63dfe21, github.com/apache/spark/pull/4267 + + [SPARK-5353] Log failures in REPL class loading + Tobias Schlatter + 2015-02-01 21:43:49 -0800 + Commit: 9f0a6e1, github.com/apache/spark/pull/4130 + + [SPARK-3996]: Shade Jetty in Spark deliverables + Patrick Wendell + 2015-02-01 21:13:57 -0800 + Commit: a15f6e3, github.com/apache/spark/pull/4285 + + [SPARK-4001][MLlib] adding parallel FP-Growth algorithm for frequent pattern mining in MLlib + Jacky Li , Jacky Li , Xiangrui Meng + 2015-02-01 20:07:25 -0800 + Commit: 859f724, github.com/apache/spark/pull/2847 + + [Spark-5406][MLlib] LocalLAPACK mode in RowMatrix.computeSVD should have much smaller upper bound + Yuhao Yang + 2015-02-01 19:40:26 -0800 + Commit: d85cd4e, github.com/apache/spark/pull/4200 + + [SPARK-5465] [SQL] Fixes filter push-down for Parquet data source + Cheng Lian + 2015-02-01 18:52:39 -0800 + Commit: ec10032, github.com/apache/spark/pull/4255 + + [SPARK-5262] [SPARK-5244] [SQL] add coalesce in SQLParser and widen types for parameters of coalesce + Daoyuan Wang + 2015-02-01 18:51:38 -0800 + Commit: 8cf4a1f, github.com/apache/spark/pull/4057 + + [SPARK-5196][SQL] Support `comment` in Create Table Field DDL + OopsOutOfMemory + 2015-02-01 18:41:49 -0800 + Commit: 1b56f1d, github.com/apache/spark/pull/3999 + + [SPARK-1825] Make Windows Spark client work fine with Linux YARN cluster + Masayoshi TSUZUKI + 2015-02-01 18:26:28 -0800 + Commit: 7712ed5, github.com/apache/spark/pull/3943 + + [SPARK-5176] The thrift server does not support cluster mode + Tom Panning + 2015-02-01 17:57:31 -0800 + Commit: 1ca0a10, github.com/apache/spark/pull/4137 + + [SPARK-5155] Build fails with spark-ganglia-lgpl profile + Kousuke Saruta + 2015-02-01 17:53:56 -0800 + Commit: c80194b, github.com/apache/spark/pull/4303 + + [Minor][SQL] Little refactor DataFrame related codes + Liang-Chi Hsieh + 2015-02-01 17:52:18 -0800 + Commit: ef89b82, github.com/apache/spark/pull/4298 + + [SPARK-4859][Core][Streaming] Refactor LiveListenerBus and StreamingListenerBus + zsxwing + 2015-02-01 17:47:51 -0800 + Commit: 883bc88, github.com/apache/spark/pull/4006 + + [SPARK-5424][MLLIB] make the new ALS impl take generic ID types + Xiangrui Meng + 2015-02-01 14:13:31 -0800 + Commit: 4a17122, github.com/apache/spark/pull/4281 + + [SPARK-5207] [MLLIB] StandardScalerModel mean and variance re-use + Octavian Geagla + 2015-02-01 09:21:14 -0800 + Commit: bdb0680, github.com/apache/spark/pull/4140 + + [SPARK-5422] Add support for sending Graphite metrics via UDP + Ryan Williams + 2015-01-31 23:41:05 -0800 + Commit: 80bd715, github.com/apache/spark/pull/4218 + + SPARK-3359 [CORE] [DOCS] `sbt/sbt unidoc` doesn't work with Java 8 + Sean Owen + 2015-01-31 10:40:42 -0800 + Commit: c84d5a1, github.com/apache/spark/pull/4193 + + [SPARK-3975] Added support for BlockMatrix addition and multiplication + Burak Yavuz , Burak Yavuz , Burak Yavuz , Burak Yavuz , Burak Yavuz + 2015-01-31 00:47:30 -0800 + Commit: ef8974b, github.com/apache/spark/pull/4274 + + [MLLIB][SPARK-3278] Monotone (Isotonic) regression using parallel pool adjacent violators algorithm + martinzapletal , Xiangrui Meng , Martin Zapletal + 2015-01-31 00:46:02 -0800 + Commit: 34250a6, github.com/apache/spark/pull/3519 + + [SPARK-5307] Add a config option for SerializationDebugger. + Reynold Xin + 2015-01-31 00:06:36 -0800 + Commit: 6364083, github.com/apache/spark/pull/4297 + + [SQL] remove redundant field "childOutput" from execution.Aggregate, use child.output instead + kai + 2015-01-30 23:19:10 -0800 + Commit: f54c9f6, github.com/apache/spark/pull/4291 + + [SPARK-5307] SerializationDebugger + Reynold Xin + 2015-01-30 22:34:10 -0800 + Commit: 740a568, github.com/apache/spark/pull/4098 + + [SPARK-5504] [sql] convertToCatalyst should support nested arrays + Joseph K. Bradley + 2015-01-30 15:40:14 -0800 + Commit: e643de4, github.com/apache/spark/pull/4295 + + SPARK-5400 [MLlib] Changed name of GaussianMixtureEM to GaussianMixture + Travis Galoppo + 2015-01-30 15:32:25 -0800 + Commit: 9869773, github.com/apache/spark/pull/4290 + + [SPARK-4259][MLlib]: Add Power Iteration Clustering Algorithm with Gaussian Similarity Function + sboeschhuawei , Fan Jiang , Jiang Fan , Stephen Boesch , Xiangrui Meng + 2015-01-30 14:09:49 -0800 + Commit: f377431, github.com/apache/spark/pull/4254 + + [SPARK-5486] Added validate method to BlockMatrix + Burak Yavuz + 2015-01-30 13:59:10 -0800 + Commit: 6ee8338, github.com/apache/spark/pull/4279 + + [SPARK-5496][MLLIB] Allow both classification and Classification in Algo for trees. + Xiangrui Meng + 2015-01-30 10:08:07 -0800 + Commit: 0a95085, github.com/apache/spark/pull/4287 + + [MLLIB] SPARK-4846: throw a RuntimeException and give users hints to increase the minCount + Joseph J.C. Tang + 2015-01-30 10:07:26 -0800 + Commit: 54d9575, github.com/apache/spark/pull/4247 + + SPARK-5393. Flood of util.RackResolver log messages after SPARK-1714 + Sandy Ryza + 2015-01-30 11:31:54 -0600 + Commit: 254eaa4, github.com/apache/spark/pull/4192 + + [SPARK-5457][SQL] Add missing DSL for ApproxCountDistinct. + Takuya UESHIN + 2015-01-30 01:21:35 -0800 + Commit: 6f21dce, github.com/apache/spark/pull/4250 + + [SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees + Kazuki Taniguchi + 2015-01-30 00:39:44 -0800 + Commit: bc1fc9b, github.com/apache/spark/pull/3951 + + [SPARK-5322] Added transpose functionality to BlockMatrix + Burak Yavuz + 2015-01-29 21:26:29 -0800 + Commit: dd4d84c, github.com/apache/spark/pull/4275 + + [SQL] Support df("*") to select all columns in a data frame. + Reynold Xin + 2015-01-29 19:09:08 -0800 + Commit: 80def9d, github.com/apache/spark/pull/4283 + + [SPARK-5462] [SQL] Use analyzed query plan in DataFrame.apply() + Josh Rosen + 2015-01-29 18:23:05 -0800 + Commit: 22271f9, github.com/apache/spark/pull/4282 + + [SPARK-5395] [PySpark] fix python process leak while coalesce() + Davies Liu + 2015-01-29 17:28:37 -0800 + Commit: 5c746ee, github.com/apache/spark/pull/4238 + + [SQL] DataFrame API improvements + Reynold Xin + 2015-01-29 17:24:00 -0800 + Commit: ce9c43b, github.com/apache/spark/pull/4280 + + Revert "[WIP] [SPARK-3996]: Shade Jetty in Spark deliverables" + Patrick Wendell + 2015-01-29 17:14:27 -0800 + Commit: d2071e8 + + remove 'return' + Yoshihiro Shimizu + 2015-01-29 16:55:00 -0800 + Commit: 5338772, github.com/apache/spark/pull/4268 + + [WIP] [SPARK-3996]: Shade Jetty in Spark deliverables + Patrick Wendell + 2015-01-29 16:31:19 -0800 + Commit: f240fe3, github.com/apache/spark/pull/4252 + + [SPARK-5464] Fix help() for Python DataFrame instances + Josh Rosen + 2015-01-29 16:23:20 -0800 + Commit: 0bb15f2, github.com/apache/spark/pull/4278 + + [SPARK-4296][SQL] Trims aliases when resolving and checking aggregate expressions + Yin Huai , Cheng Lian + 2015-01-29 15:49:34 -0800 + Commit: c00d517, github.com/apache/spark/pull/4010 + + [SPARK-5373][SQL] Literal in agg grouping expressions leads to incorrect result + wangfei + 2015-01-29 15:47:13 -0800 + Commit: c1b3eeb, github.com/apache/spark/pull/4169 + + [SPARK-5367][SQL] Support star expression in udf + wangfei , scwf + 2015-01-29 15:44:53 -0800 + Commit: fbaf9e0, github.com/apache/spark/pull/4163 + + [SPARK-4786][SQL]: Parquet filter pushdown for castable types + Yash Datta + 2015-01-29 15:42:23 -0800 + Commit: de221ea, github.com/apache/spark/pull/4156 + + [SPARK-5309][SQL] Add support for dictionaries in PrimitiveConverter for Strin... + Michael Davies + 2015-01-29 15:40:59 -0800 + Commit: 940f375, github.com/apache/spark/pull/4187 + + [SPARK-5429][SQL] Use javaXML plan serialization for Hive golden answers on Hive 0.13.1 + Liang-Chi Hsieh + 2015-01-29 15:28:22 -0800 + Commit: bce0ba1, github.com/apache/spark/pull/4223 + + [SPARK-5445][SQL] Consolidate Java and Scala DSL static methods. + Reynold Xin + 2015-01-29 15:13:09 -0800 + Commit: 7156322, github.com/apache/spark/pull/4276 + + [SPARK-5466] Add explicit guava dependencies where needed. + Marcelo Vanzin + 2015-01-29 13:00:45 -0800 + Commit: f9e5694, github.com/apache/spark/pull/4272 + + [SPARK-5477] refactor stat.py + Xiangrui Meng + 2015-01-29 10:11:44 -0800 + Commit: a3dc618, github.com/apache/spark/pull/4266 + + [SQL] Various DataFrame DSL update. + Reynold Xin + 2015-01-29 00:01:10 -0800 + Commit: 5ad78f6, github.com/apache/spark/pull/4260 + + [SPARK-3977] Conversion methods for BlockMatrix to other Distributed Matrices + Burak Yavuz + 2015-01-28 23:42:07 -0800 + Commit: a63be1a, github.com/apache/spark/pull/4256 + + [SPARK-5445][SQL] Made DataFrame dsl usable in Java + Reynold Xin + 2015-01-28 19:10:32 -0800 + Commit: 5b9760d, github.com/apache/spark/pull/4241 + + [SPARK-5430] move treeReduce and treeAggregate from mllib to core + Xiangrui Meng + 2015-01-28 17:26:03 -0800 + Commit: 4ee79c7, github.com/apache/spark/pull/4228 + + [SPARK-4586][MLLIB] Python API for ML pipeline and parameters + Xiangrui Meng , Davies Liu + 2015-01-28 17:14:23 -0800 + Commit: e80dc1c, github.com/apache/spark/pull/4151 + + [SPARK-5441][pyspark] Make SerDeUtil PairRDD to Python conversions more robust + Michael Nazario + 2015-01-28 13:55:01 -0800 + Commit: e023112, github.com/apache/spark/pull/4236 + + [SPARK-4387][PySpark] Refactoring python profiling code to make it extensible + Yandu Oppacher , Davies Liu + 2015-01-28 13:48:06 -0800 + Commit: 3bead67, github.com/apache/spark/pull/3255. + + [SPARK-5417] Remove redundant executor-id set() call + Ryan Williams + 2015-01-28 13:04:52 -0800 + Commit: a731314, github.com/apache/spark/pull/4213 + + [SPARK-5434] [EC2] Preserve spaces in EC2 path + Nicholas Chammas + 2015-01-28 12:56:03 -0800 + Commit: d44ee43, github.com/apache/spark/pull/4224 + + [SPARK-5437] Fix DriverSuite and SparkSubmitSuite timeout issues + Andrew Or + 2015-01-28 12:52:31 -0800 + Commit: 84b6ecd, github.com/apache/spark/pull/4230 + + [SPARK-4955]With executor dynamic scaling enabled,executor shoude be added or killed in yarn-cluster mode. + lianhuiwang + 2015-01-28 12:50:57 -0800 + Commit: 81f8f34, github.com/apache/spark/pull/3962 + + [SPARK-5440][pyspark] Add toLocalIterator to pyspark rdd + Michael Nazario + 2015-01-28 12:47:12 -0800 + Commit: 456c11f, github.com/apache/spark/pull/4237 + + SPARK-1934 [CORE] "this" reference escape to "selectorThread" during construction in ConnectionManager + Sean Owen + 2015-01-28 12:44:35 -0800 + Commit: 9b18009, github.com/apache/spark/pull/4225 + + [SPARK-5188][BUILD] make-distribution.sh should support curl, not only wget to get Tachyon + Kousuke Saruta + 2015-01-28 12:43:22 -0800 + Commit: e902dc4, github.com/apache/spark/pull/3988 + + SPARK-5458. Refer to aggregateByKey instead of combineByKey in docs + Sandy Ryza + 2015-01-28 12:41:23 -0800 + Commit: 406f6d3, github.com/apache/spark/pull/4251 + + [SPARK-5447][SQL] Replaced reference to SchemaRDD with DataFrame. + Reynold Xin + 2015-01-28 12:10:01 -0800 + Commit: c8e934e, github.com/apache/spark/pull/4242 + + [SPARK-5361]Multiple Java RDD <-> Python RDD conversions not working correctly + Winston Chen + 2015-01-28 11:08:44 -0800 + Commit: 453d799, github.com/apache/spark/pull/4146 + + [SPARK-5291][CORE] Add timestamp and reason why an executor is removed to SparkListenerExecutorAdded and SparkListenerExecutorRemoved + Kousuke Saruta + 2015-01-28 11:02:51 -0800 + Commit: 0b35fcd, github.com/apache/spark/pull/4082 + + [SPARK-3974][MLlib] Distributed Block Matrix Abstractions + Burak Yavuz , Xiangrui Meng , Burak Yavuz , Burak Yavuz , Burak Yavuz + 2015-01-28 10:06:37 -0800 + Commit: eeb53bf, github.com/apache/spark/pull/3200 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2015-01-28 02:15:14 -0800 + Commit: 622ff09, github.com/apache/spark/pull/1480 + + [SPARK-5415] bump sbt to version to 0.13.7 + Ryan Williams + 2015-01-28 02:13:06 -0800 + Commit: 661d3f9, github.com/apache/spark/pull/4211 + + [SPARK-4809] Rework Guava library shading. + Marcelo Vanzin + 2015-01-28 00:29:29 -0800 + Commit: 37a5e27, github.com/apache/spark/pull/3658 + + [SPARK-5097][SQL] Test cases for DataFrame expressions. + Reynold Xin + 2015-01-27 18:10:49 -0800 + Commit: d743732, github.com/apache/spark/pull/4235 + + [SPARK-5097][SQL] DataFrame + Reynold Xin , Davies Liu + 2015-01-27 16:08:24 -0800 + Commit: 119f45d, github.com/apache/spark/pull/4173 + + SPARK-5199. FS read metrics should support CombineFileSplits and track bytes from all FSs + Sandy Ryza + 2015-01-27 15:42:55 -0800 + Commit: b1b35ca, github.com/apache/spark/pull/4050 + + [MLlib] fix python example of ALS in guide + Davies Liu + 2015-01-27 15:33:01 -0800 + Commit: fdaad4e, github.com/apache/spark/pull/4226 + + SPARK-5308 [BUILD] MD5 / SHA1 hash format doesn't match standard Maven output + Sean Owen + 2015-01-27 10:22:50 -0800 + Commit: ff356e2, github.com/apache/spark/pull/4161 + + [SPARK-5321] Support for transposing local matrices + Burak Yavuz + 2015-01-27 01:46:17 -0800 + Commit: 9142674, github.com/apache/spark/pull/4109 + + [SPARK-5419][Mllib] Fix the logic in Vectors.sqdist + Liang-Chi Hsieh + 2015-01-27 01:29:14 -0800 + Commit: 7b0ed79, github.com/apache/spark/pull/4217 + + [SPARK-3726] [MLlib] Allow sampling_rate not equal to 1.0 in RandomForests + MechCoder + 2015-01-26 19:46:17 -0800 + Commit: d6894b1, github.com/apache/spark/pull/4073 + + [SPARK-5119] java.lang.ArrayIndexOutOfBoundsException on trying to train... + lewuathe + 2015-01-26 18:03:21 -0800 + Commit: f2ba5c6, github.com/apache/spark/pull/3975 + + [SPARK-5052] Add common/base classes to fix guava methods signatures. + Elmer Garduno + 2015-01-26 17:40:48 -0800 + Commit: 661e0fc, github.com/apache/spark/pull/3874 + + SPARK-960 [CORE] [TEST] JobCancellationSuite "two jobs sharing the same stage" is broken + Sean Owen + 2015-01-26 14:32:27 -0800 + Commit: 0497ea5, github.com/apache/spark/pull/4180 + + Fix command spaces issue in make-distribution.sh + David Y. Ross + 2015-01-26 14:26:10 -0800 + Commit: b38034e, github.com/apache/spark/pull/4126 + + SPARK-4147 [CORE] Reduce log4j dependency + Sean Owen + 2015-01-26 14:23:42 -0800 + Commit: 54e7b45, github.com/apache/spark/pull/4190 + + [SPARK-5339][BUILD] build/mvn doesn't work because of invalid URL for maven's tgz. + Kousuke Saruta + 2015-01-26 13:07:49 -0800 + Commit: c094c73, github.com/apache/spark/pull/4124 + + [SPARK-5355] use j.u.c.ConcurrentHashMap instead of TrieMap + Davies Liu + 2015-01-26 12:51:32 -0800 + Commit: 1420931, github.com/apache/spark/pull/4208 + + [SPARK-5384][mllib] Vectors.sqdist returns inconsistent results for sparse/dense vectors when the vectors have different lengths + Yuhao Yang + 2015-01-25 22:18:09 -0800 + Commit: 8125168, github.com/apache/spark/pull/4183 + + [SPARK-5268] don't stop CoarseGrainedExecutorBackend for irrelevant DisassociatedEvent + CodingCat + 2015-01-25 19:28:53 -0800 + Commit: 8df9435, github.com/apache/spark/pull/4063 + + SPARK-4430 [STREAMING] [TEST] Apache RAT Checks fail spuriously on test files + Sean Owen + 2015-01-25 19:16:44 -0800 + Commit: 0528b85, github.com/apache/spark/pull/4189 + + [SPARK-5326] Show fetch wait time as optional metric in the UI + Kay Ousterhout + 2015-01-25 16:48:26 -0800 + Commit: fc2168f, github.com/apache/spark/pull/4110 + + [SPARK-5344][WebUI] HistoryServer cannot recognize that inprogress file was renamed to completed file + Kousuke Saruta + 2015-01-25 15:34:20 -0800 + Commit: 8f5c827, github.com/apache/spark/pull/4132 + + SPARK-4506 [DOCS] Addendum: Update more docs to reflect that standalone works in cluster mode + Sean Owen + 2015-01-25 15:25:05 -0800 + Commit: 9f64357, github.com/apache/spark/pull/4160 + + SPARK-5382: Use SPARK_CONF_DIR in spark-class if it is defined + Jacek Lewandowski + 2015-01-25 15:15:09 -0800 + Commit: 1c30afd, github.com/apache/spark/pull/4179 + + SPARK-3782 [CORE] Direct use of log4j in AkkaUtils interferes with certain logging configurations + Sean Owen + 2015-01-25 15:11:57 -0800 + Commit: 383425a, github.com/apache/spark/pull/4184 + + SPARK-3852 [DOCS] Document spark.driver.extra* configs + Sean Owen + 2015-01-25 15:08:05 -0800 + Commit: c586b45, github.com/apache/spark/pull/4185 + + [SPARK-5402] log executor ID at executor-construction time + Ryan Williams + 2015-01-25 14:20:02 -0800 + Commit: aea2548, github.com/apache/spark/pull/4195 + + [SPARK-5401] set executor ID before creating MetricsSystem + Ryan Williams + 2015-01-25 14:17:59 -0800 + Commit: 2d9887b, github.com/apache/spark/pull/4194 + + Add comment about defaultMinPartitions + Idan Zalzberg + 2015-01-25 11:28:05 -0800 + Commit: 412a58e, github.com/apache/spark/pull/4102 + + Closes #4157 + Reynold Xin + 2015-01-25 00:24:59 -0800 + Commit: d22ca1e + + [SPARK-5214][Test] Add a test to demonstrate EventLoop can be stopped in the event thread + zsxwing + 2015-01-24 11:00:35 -0800 + Commit: 0d1e67e, github.com/apache/spark/pull/4174 + + [SPARK-5058] Part 2. Typos and broken URL + Jongyoul Lee + 2015-01-23 23:34:11 -0800 + Commit: 09e09c5, github.com/apache/spark/pull/4172 + + [SPARK-5351][GraphX] Do not use Partitioner.defaultPartitioner as a partitioner of EdgeRDDImp... + Takeshi Yamamuro + 2015-01-23 19:25:15 -0800 + Commit: e224dbb, github.com/apache/spark/pull/4136 + + [SPARK-5063] More helpful error messages for several invalid operations + Josh Rosen + 2015-01-23 17:53:15 -0800 + Commit: cef1f09, github.com/apache/spark/pull/3884 + + [SPARK-3541][MLLIB] New ALS implementation with improved storage + Xiangrui Meng + 2015-01-22 22:09:13 -0800 + Commit: ea74365, github.com/apache/spark/pull/3720 + + [SPARK-5315][Streaming] Fix reduceByWindow Java API not work bug + jerryshao + 2015-01-22 22:04:21 -0800 + Commit: e0f7fb7, github.com/apache/spark/pull/4104 + + [SPARK-5233][Streaming] Fix error replaying of WAL introduced bug + jerryshao + 2015-01-22 21:58:53 -0800 + Commit: 3c3fa63, github.com/apache/spark/pull/4032 + + SPARK-5370. [YARN] Remove some unnecessary synchronization in YarnAlloca... + Sandy Ryza + 2015-01-22 13:49:35 -0600 + Commit: 820ce03, github.com/apache/spark/pull/4164 + + [SPARK-5365][MLlib] Refactor KMeans to reduce redundant data + Liang-Chi Hsieh + 2015-01-22 08:16:35 -0800 + Commit: 246111d, github.com/apache/spark/pull/4159 + + [SPARK-5147][Streaming] Delete the received data WAL log periodically + Tathagata Das , jerryshao + 2015-01-21 23:41:44 -0800 + Commit: 3027f06, github.com/apache/spark/pull/4149 + + [SPARK-5317]Set BoostingStrategy.defaultParams With Enumeration Algo.Classification or Algo.Regression + Basin + 2015-01-21 23:06:34 -0800 + Commit: fcb3e18, github.com/apache/spark/pull/4103 + + [SPARK-3424][MLLIB] cache point distances during k-means|| init + Xiangrui Meng + 2015-01-21 21:20:31 -0800 + Commit: ca7910d, github.com/apache/spark/pull/4144 + + [SPARK-5202] [SQL] Add hql variable substitution support + Cheng Hao + 2015-01-21 17:34:18 -0800 + Commit: 27bccc5, github.com/apache/spark/pull/4003 + + [SPARK-5355] make SparkConf thread-safe + Davies Liu + 2015-01-21 16:51:42 -0800 + Commit: 9bad062, github.com/apache/spark/pull/4143 + + [SPARK-4984][CORE][WEBUI] Adding a pop-up containing the full job description when it is very long + wangfei + 2015-01-21 15:27:42 -0800 + Commit: 3be2a88, github.com/apache/spark/pull/3819 + + [SQL] [Minor] Remove deprecated parquet tests + Cheng Lian + 2015-01-21 14:38:10 -0800 + Commit: ba19689, github.com/apache/spark/pull/4116 + + Revert "[SPARK-5244] [SQL] add coalesce() in sql parser" + Josh Rosen + 2015-01-21 14:27:43 -0800 + Commit: b328ac6 + + [SPARK-5009] [SQL] Long keyword support in SQL Parsers + Cheng Hao + 2015-01-21 13:05:56 -0800 + Commit: 8361078, github.com/apache/spark/pull/3926 + + [SPARK-5244] [SQL] add coalesce() in sql parser + Daoyuan Wang + 2015-01-21 12:59:41 -0800 + Commit: 812d367, github.com/apache/spark/pull/4040 + + [SPARK-5064][GraphX] Add numEdges upperbound validation for R-MAT graph generator to prevent infinite loop + Kenji Kikushima + 2015-01-21 12:34:00 -0800 + Commit: 3ee3ab5, github.com/apache/spark/pull/3950 + + [SPARK-4749] [mllib]: Allow initializing KMeans clusters using a seed + nate.crosswhite , nxwhite-str , Xiangrui Meng + 2015-01-21 10:32:10 -0800 + Commit: 7450a99, github.com/apache/spark/pull/3610 + + [MLlib] [SPARK-5301] Missing conversions and operations on IndexedRowMatrix and CoordinateMatrix + Reza Zadeh + 2015-01-21 09:48:38 -0800 + Commit: aa1e22b, github.com/apache/spark/pull/4089 + + SPARK-1714. Take advantage of AMRMClient APIs to simplify logic in YarnA... + Sandy Ryza + 2015-01-21 10:31:54 -0600 + Commit: 2eeada3, github.com/apache/spark/pull/3765 + + [SPARK-5336][YARN]spark.executor.cores must not be less than spark.task.cpus + WangTao , WangTaoTheTonic + 2015-01-21 09:42:30 -0600 + Commit: 8c06a5f, github.com/apache/spark/pull/4123 + + [SPARK-5297][Streaming] Fix Java file stream type erasure problem + jerryshao + 2015-01-20 23:37:47 -0800 + Commit: 424d8c6, github.com/apache/spark/pull/4101 + + [HOTFIX] Update pom.xml to pull MapR's Hadoop version 2.4.1. + Kannan Rajah + 2015-01-20 23:34:04 -0800 + Commit: ec5b0f2, github.com/apache/spark/pull/4108 + + [SPARK-5275] [Streaming] include python source code + Davies Liu + 2015-01-20 22:44:58 -0800 + Commit: bad6c57, github.com/apache/spark/pull/4128 + + [SPARK-5294][WebUI] Hide tables in AllStagePages for "Active Stages, Completed Stages and Failed Stages" when they are empty + Kousuke Saruta + 2015-01-20 16:40:46 -0800 + Commit: 9a151ce, github.com/apache/spark/pull/4083 + + [SPARK-5186] [MLLIB] Vector.equals and Vector.hashCode are very inefficient + Yuhao Yang , Yuhao Yang + 2015-01-20 15:20:20 -0800 + Commit: 2f82c84, github.com/apache/spark/pull/3997 + + [SPARK-5323][SQL] Remove Row's Seq inheritance. + Reynold Xin + 2015-01-20 15:16:14 -0800 + Commit: d181c2a, github.com/apache/spark/pull/4115 + + [SPARK-5287][SQL] Add defaultSizeOf to every data type. + Yin Huai + 2015-01-20 13:26:36 -0800 + Commit: bc20a52, github.com/apache/spark/pull/4081 + + SPARK-5019 [MLlib] - GaussianMixtureModel exposes instances of MultivariateGauss... + Travis Galoppo + 2015-01-20 12:58:11 -0800 + Commit: 23e2554, github.com/apache/spark/pull/4088 + + [SPARK-5329][WebUI] UIWorkloadGenerator should stop SparkContext. + Kousuke Saruta + 2015-01-20 12:40:55 -0800 + Commit: 769aced, github.com/apache/spark/pull/4112 + + SPARK-4660: Use correct class loader in JavaSerializer (copy of PR #3840... + Jacek Lewandowski + 2015-01-20 12:38:01 -0800 + Commit: c93a57f, github.com/apache/spark/pull/4113 + + [SQL][Minor] Refactors deeply nested FP style code in BooleanSimplification + Cheng Lian + 2015-01-20 11:20:14 -0800 + Commit: 8140802, github.com/apache/spark/pull/4091 + + [SPARK-5333][Mesos] MesosTaskLaunchData occurs BufferUnderflowException + Jongyoul Lee + 2015-01-20 10:17:29 -0800 + Commit: 9d9294a, github.com/apache/spark/pull/4119 + + [SPARK-4803] [streaming] Remove duplicate RegisterReceiver message + Ilayaperumal Gopinathan + 2015-01-20 01:41:10 -0800 + Commit: 4afad9c, github.com/apache/spark/pull/3648 + + [SQL][minor] Add a log4j file for catalyst test. + Reynold Xin + 2015-01-20 00:55:25 -0800 + Commit: debc031, github.com/apache/spark/pull/4117 + + SPARK-5270 [CORE] Provide isEmpty() function in RDD API + Sean Owen + 2015-01-19 22:50:44 -0800 + Commit: 306ff18, github.com/apache/spark/pull/4074 + + [SPARK-5214][Core] Add EventLoop and change DAGScheduler to an EventLoop + zsxwing + 2015-01-19 18:15:51 -0800 + Commit: e69fb8c, github.com/apache/spark/pull/4016 + + [SPARK-4504][Examples] fix run-example failure if multiple assembly jars exist + Venkata Ramana Gollamudi + 2015-01-19 11:58:16 -0800 + Commit: 74de94e, github.com/apache/spark/pull/3377 + + [SPARK-5286][SQL] Fail to drop an invalid table when using the data source API + Yin Huai + 2015-01-19 10:45:29 -0800 + Commit: 2604bc3, github.com/apache/spark/pull/4076 + + [SPARK-5284][SQL] Insert into Hive throws NPE when a inner complex type field has a null value + Yin Huai + 2015-01-19 10:44:12 -0800 + Commit: cd5da42, github.com/apache/spark/pull/4077 + + [SPARK-5282][mllib]: RowMatrix easily gets int overflow in the memory size warning + Yuhao Yang + 2015-01-19 10:10:15 -0800 + Commit: 4432568, github.com/apache/spark/pull/4069 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2015-01-19 02:05:24 -0800 + Commit: 1ac1c1d, github.com/apache/spark/pull/3584 + + [SPARK-5088] Use spark-class for running executors directly + Jongyoul Lee + 2015-01-19 02:01:56 -0800 + Commit: 4a4f9cc, github.com/apache/spark/pull/3897 + + [SPARK-3288] All fields in TaskMetrics should be private and use getters/setters + Ilya Ganelin + 2015-01-19 01:32:22 -0800 + Commit: 3453d57, github.com/apache/spark/pull/4020 + + SPARK-5217 Spark UI should report pending stages during job execution on AllStagesPage. + Prashant Sharma + 2015-01-19 01:28:42 -0800 + Commit: 851b6a9, github.com/apache/spark/pull/4043 + + [SQL] fix typo in class description + Jacky Li + 2015-01-18 23:59:08 -0800 + Commit: 7dbf1fd, github.com/apache/spark/pull/4100 + + [SQL][minor] Put DataTypes.java in java dir. + Reynold Xin + 2015-01-18 16:35:40 -0800 + Commit: 1955645, github.com/apache/spark/pull/4097 + + [SQL][Minor] Update sql doc according to data type APIs changes + scwf + 2015-01-18 11:03:13 -0800 + Commit: 1a200a3, github.com/apache/spark/pull/4095 + + [SPARK-5279][SQL] Use java.math.BigDecimal as the exposed Decimal type. + Reynold Xin + 2015-01-18 11:01:42 -0800 + Commit: 1727e08, github.com/apache/spark/pull/4092 + + [HOTFIX]: Minor clean up regarding skipped artifacts in build files. + Patrick Wendell + 2015-01-17 23:15:12 -0800 + Commit: ad16da1, github.com/apache/spark/pull/4080 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2015-01-17 20:39:54 -0800 + Commit: e12b5b6, github.com/apache/spark/pull/681 + + [SQL][Minor] Added comments and examples to explain BooleanSimplification + Reynold Xin + 2015-01-17 17:35:53 -0800 + Commit: e7884bc, github.com/apache/spark/pull/4090 + + [SPARK-5096] Use sbt tasks instead of vals to get hadoop version + Michael Armbrust + 2015-01-17 17:03:07 -0800 + Commit: 6999910, github.com/apache/spark/pull/3905 + + [SPARK-4937][SQL] Comment for the newly optimization rules in `BooleanSimplification` + scwf + 2015-01-17 15:51:24 -0800 + Commit: c1f3c27, github.com/apache/spark/pull/4086 + + [SQL][minor] Improved Row documentation. + Reynold Xin + 2015-01-17 00:11:08 -0800 + Commit: f3bfc76, github.com/apache/spark/pull/4085 + + [SPARK-5193][SQL] Remove Spark SQL Java-specific API. + Reynold Xin + 2015-01-16 21:09:06 -0800 + Commit: 61b427d, github.com/apache/spark/pull/4065 + + [SPARK-4937][SQL] Adding optimization to simplify the And, Or condition in spark sql + scwf , wangfei + 2015-01-16 14:01:22 -0800 + Commit: ee1c1f3, github.com/apache/spark/pull/3778 + + [SPARK-733] Add documentation on use of accumulators in lazy transformation + Ilya Ganelin + 2015-01-16 13:25:17 -0800 + Commit: fd3a8a1, github.com/apache/spark/pull/4022 + + [SPARK-4923][REPL] Add Developer API to REPL to allow re-publishing the REPL jar + Chip Senkbeil , Chip Senkbeil + 2015-01-16 12:56:40 -0800 + Commit: d05c9ee, github.com/apache/spark/pull/4034 + + [WebUI] Fix collapse of WebUI layout + Kousuke Saruta + 2015-01-16 12:19:08 -0800 + Commit: ecf943d, github.com/apache/spark/pull/3995 + + [SPARK-5231][WebUI] History Server shows wrong job submission time. + Kousuke Saruta + 2015-01-16 10:05:11 -0800 + Commit: e8422c5, github.com/apache/spark/pull/4029 + + [DOCS] Fix typo in return type of cogroup + Sean Owen + 2015-01-16 09:28:44 -0800 + Commit: f6b852a, github.com/apache/spark/pull/4072 + + [SPARK-5201][CORE] deal with int overflow in the ParallelCollectionRDD.slice method + Ye Xianjin + 2015-01-16 09:20:53 -0800 + Commit: e200ac8, github.com/apache/spark/pull/4002 + + [SPARK-1507][YARN]specify # cores for ApplicationMaster + WangTaoTheTonic , WangTao + 2015-01-16 09:16:56 -0800 + Commit: 2be82b1, github.com/apache/spark/pull/4018 + + [SPARK-4092] [CORE] Fix InputMetrics for coalesce'd Rdds + Kostas Sakellis + 2015-01-15 18:48:39 -0800 + Commit: a79a9f9, github.com/apache/spark/pull/3120 + + [SPARK-4857] [CORE] Adds Executor membership events to SparkListener + Kostas Sakellis + 2015-01-15 17:53:42 -0800 + Commit: 96c2c71, github.com/apache/spark/pull/3711 + + [Minor] Fix tiny typo in BlockManager + Kousuke Saruta + 2015-01-15 17:07:44 -0800 + Commit: 65858ba, github.com/apache/spark/pull/4046 + + [SPARK-5274][SQL] Reconcile Java and Scala UDFRegistration. + Reynold Xin + 2015-01-15 16:15:12 -0800 + Commit: 1881431, github.com/apache/spark/pull/4056 + + [SPARK-5224] [PySpark] improve performance of parallelize list/ndarray + Davies Liu + 2015-01-15 11:40:41 -0800 + Commit: 3c8650c, github.com/apache/spark/pull/4024 + + [SPARK-5193][SQL] Tighten up HiveContext API + Reynold Xin + 2015-01-14 20:31:02 -0800 + Commit: 4b325c7, github.com/apache/spark/pull/4054 + + [SPARK-5254][MLLIB] remove developers section from spark.ml guide + Xiangrui Meng + 2015-01-14 18:54:17 -0800 + Commit: 6abc45e, github.com/apache/spark/pull/4053 + + [SPARK-5193][SQL] Tighten up SQLContext API + Reynold Xin + 2015-01-14 18:36:15 -0800 + Commit: cfa397c, github.com/apache/spark/pull/4049 + + [SPARK-5254][MLLIB] Update the user guide to position spark.ml better + Xiangrui Meng + 2015-01-14 17:50:33 -0800 + Commit: 13d2406, github.com/apache/spark/pull/4052 + + [SPARK-5234][ml]examples for ml don't have sparkContext.stop + Yuhao Yang + 2015-01-14 11:53:43 -0800 + Commit: 76389c5, github.com/apache/spark/pull/4044 + + [SPARK-5235] Make SQLConf Serializable + Alex Baretta + 2015-01-14 11:51:55 -0800 + Commit: 2fd7f72, github.com/apache/spark/pull/4031 + + [SPARK-4014] Add TaskContext.attemptNumber and deprecate TaskContext.attemptId + Josh Rosen + 2015-01-14 11:45:40 -0800 + Commit: 259936b, github.com/apache/spark/pull/3849 + + [SPARK-5228][WebUI] Hide tables for "Active Jobs/Completed Jobs/Failed Jobs" when they are empty + Kousuke Saruta + 2015-01-14 11:10:29 -0800 + Commit: 9d4449c, github.com/apache/spark/pull/4028 + + [SPARK-2909] [MLlib] [PySpark] SparseVector in pyspark now supports indexing + MechCoder + 2015-01-14 11:03:11 -0800 + Commit: 5840f54, github.com/apache/spark/pull/4025 + + [SQL] some comments fix for GROUPING SETS + Daoyuan Wang + 2015-01-14 09:50:01 -0800 + Commit: 38bdc99, github.com/apache/spark/pull/4000 + + [SPARK-5211][SQL]Restore HiveMetastoreTypes.toDataType + Yin Huai + 2015-01-14 09:47:30 -0800 + Commit: 81f72a0, github.com/apache/spark/pull/4026 + + [SPARK-5248] [SQL] move sql.types.decimal.Decimal to sql.types.Decimal + Daoyuan Wang + 2015-01-14 09:36:59 -0800 + Commit: a3f7421, github.com/apache/spark/pull/4041 + + [SPARK-5167][SQL] Move Row into sql package and make it usable for Java. + Reynold Xin + 2015-01-14 00:38:55 -0800 + Commit: d5eeb35, github.com/apache/spark/pull/4030 + + [SPARK-5123][SQL] Reconcile Java/Scala API for data types. + Reynold Xin + 2015-01-13 17:16:41 -0800 + Commit: f996909, github.com/apache/spark/pull/3958 + + [SPARK-5168] Make SQLConf a field rather than mixin in SQLContext + Reynold Xin + 2015-01-13 13:30:35 -0800 + Commit: 14e3f11, github.com/apache/spark/pull/3965 + + [SPARK-4912][SQL] Persistent tables for the Spark SQL data sources api + Yin Huai , Michael Armbrust + 2015-01-13 13:01:27 -0800 + Commit: 6463e0b, github.com/apache/spark/pull/3960 + + [SPARK-5223] [MLlib] [PySpark] fix MapConverter and ListConverter in MLlib + Davies Liu + 2015-01-13 12:50:31 -0800 + Commit: 8ead999, github.com/apache/spark/pull/4023 + + [SPARK-5131][Streaming][DOC]: There is a discrepancy in WAL implementation and configuration doc. + uncleGen + 2015-01-13 10:07:19 -0800 + Commit: 39e333e, github.com/apache/spark/pull/3930 + + [SPARK-4697][YARN]System properties should override environment variables + WangTaoTheTonic , WangTao + 2015-01-13 09:43:48 -0800 + Commit: 9dea64e, github.com/apache/spark/pull/3557 + + [SPARK-5006][Deploy]spark.port.maxRetries doesn't work + WangTaoTheTonic , WangTao + 2015-01-13 09:28:21 -0800 + Commit: f7741a9, github.com/apache/spark/pull/3841 + + [SPARK-5138][SQL] Ensure schema can be inferred from a namedtuple + Gabe Mulley + 2015-01-12 21:44:51 -0800 + Commit: 1e42e96, github.com/apache/spark/pull/3978 + + [SPARK-5049][SQL] Fix ordering of partition columns in ParquetTableScan + Michael Armbrust + 2015-01-12 15:19:09 -0800 + Commit: 5d9fa55, github.com/apache/spark/pull/3990 + + [SPARK-4999][Streaming] Change storeInBlockManager to false by default + jerryshao + 2015-01-12 13:14:44 -0800 + Commit: 3aed305, github.com/apache/spark/pull/3906 + + SPARK-5172 [BUILD] spark-examples-***.jar shades a wrong Hadoop distribution + Sean Owen + 2015-01-12 12:15:34 -0800 + Commit: aff49a3, github.com/apache/spark/pull/3992 + + [SPARK-5078] Optionally read from SPARK_LOCAL_HOSTNAME + Michael Armbrust + 2015-01-12 11:57:59 -0800 + Commit: a3978f3, github.com/apache/spark/pull/3893 + + SPARK-4159 [BUILD] Addendum: improve running of single test after enabling Java tests + Sean Owen + 2015-01-12 11:00:56 -0800 + Commit: 13e610b, github.com/apache/spark/pull/3993 + + [SPARK-5102][Core]subclass of MapStatus needs to be registered with Kryo + lianhuiwang + 2015-01-12 10:57:12 -0800 + Commit: ef9224e, github.com/apache/spark/pull/4007 + + [SPARK-5200] Disable web UI in Hive ThriftServer tests + Josh Rosen + 2015-01-12 10:47:12 -0800 + Commit: 82fd38d, github.com/apache/spark/pull/3998 + + SPARK-5018 [MLlib] [WIP] Make MultivariateGaussian public + Travis Galoppo + 2015-01-11 21:31:16 -0800 + Commit: 2130de9, github.com/apache/spark/pull/3923 + + [SPARK-4033][Examples]Input of the SparkPi too big causes the emption exception + huangzhaowei + 2015-01-11 16:32:47 -0800 + Commit: f38ef65, github.com/apache/spark/pull/2874 + + [SPARK-4951][Core] Fix the issue that a busy executor may be killed + zsxwing + 2015-01-11 16:23:28 -0800 + Commit: 6942b97, github.com/apache/spark/pull/3783 + + [SPARK-5073] spark.storage.memoryMapThreshold have two default value + lewuathe + 2015-01-11 13:50:42 -0800 + Commit: 1656aae, github.com/apache/spark/pull/3900 + + [SPARK-5032] [graphx] Remove GraphX MIMA exclude for 1.3 + Joseph K. Bradley + 2015-01-10 17:25:39 -0800 + Commit: 3313260, github.com/apache/spark/pull/3856 + + [SPARK-5029][SQL] Enable from follow multiple brackets + scwf + 2015-01-10 17:07:34 -0800 + Commit: d22a31f, github.com/apache/spark/pull/3853 + + [SPARK-4871][SQL] Show sql statement in spark ui when run sql with spark-sql + wangfei + 2015-01-10 17:04:56 -0800 + Commit: 92d9a70, github.com/apache/spark/pull/3718 + + [Minor]Resolve sbt warnings during build (MQTTStreamSuite.scala). + GuoQiang Li + 2015-01-10 15:38:43 -0800 + Commit: 8a29dc7, github.com/apache/spark/pull/3989 + + [SPARK-5181] do not print writing WAL log when WAL is disabled + CodingCat + 2015-01-10 15:35:41 -0800 + Commit: f0d558b, github.com/apache/spark/pull/3985 + + [SPARK-4692] [SQL] Support ! boolean logic operator like NOT + YanTangZhai , Michael Armbrust + 2015-01-10 15:05:23 -0800 + Commit: 0ca51cc, github.com/apache/spark/pull/3555 + + [SPARK-5187][SQL] Fix caching of tables with HiveUDFs in the WHERE clause + Michael Armbrust + 2015-01-10 14:25:45 -0800 + Commit: 3684fd2, github.com/apache/spark/pull/3987 + + SPARK-4963 [SQL] Add copy to SQL's Sample operator + Yanbo Liang + 2015-01-10 14:16:37 -0800 + Commit: 77106df, github.com/apache/spark/pull/3827 + + [SPARK-4861][SQL] Refactory command in spark sql + scwf + 2015-01-10 14:08:04 -0800 + Commit: b3e86dc, github.com/apache/spark/pull/3948 + + [SPARK-4574][SQL] Adding support for defining schema in foreign DDL commands. + scwf , Yin Huai , Fei Wang , wangfei + 2015-01-10 13:53:21 -0800 + Commit: 693a323, github.com/apache/spark/pull/3431 + + [SPARK-4943][SQL] Allow table name having dot for db/catalog + Alex Liu + 2015-01-10 13:23:09 -0800 + Commit: 4b39fd1, github.com/apache/spark/pull/3941 + + [SPARK-4925][SQL] Publish Spark SQL hive-thriftserver maven artifact + Alex Liu + 2015-01-10 13:19:12 -0800 + Commit: 1e56eba, github.com/apache/spark/pull/3766 + + [SPARK-5141][SQL]CaseInsensitiveMap throws java.io.NotSerializableException + luogankun + 2015-01-09 20:38:41 -0800 + Commit: 545dfcb, github.com/apache/spark/pull/3944 + + [SPARK-4406] [MLib] FIX: Validate k in SVD + MechCoder + 2015-01-09 17:45:18 -0800 + Commit: 4554529, github.com/apache/spark/pull/3945 + + [SPARK-4990][Deploy]to find default properties file, search SPARK_CONF_DIR first + WangTaoTheTonic , WangTao + 2015-01-09 17:10:02 -0800 + Commit: 8782eb9, github.com/apache/spark/pull/3823 + + [Minor] Fix import order and other coding style + bilna , Bilna P + 2015-01-09 14:45:28 -0800 + Commit: 4e1f12d, github.com/apache/spark/pull/3966 + + [DOC] Fixed Mesos version in doc from 0.18.1 to 0.21.0 + Kousuke Saruta + 2015-01-09 14:40:45 -0800 + Commit: ae62872, github.com/apache/spark/pull/3982 + + [SPARK-4737] Task set manager properly handles serialization errors + mcheah + 2015-01-09 14:16:20 -0800 + Commit: e0f28e0, github.com/apache/spark/pull/3638 + + [SPARK-1953][YARN]yarn client mode Application Master memory size is same as driver memory... + WangTaoTheTonic + 2015-01-09 13:20:32 -0800 + Commit: e966452, github.com/apache/spark/pull/3607 + + [SPARK-5015] [mllib] Random seed for GMM + make test suite deterministic + Joseph K. Bradley + 2015-01-09 13:00:15 -0800 + Commit: 7e8e62a, github.com/apache/spark/pull/3981 + + [SPARK-3619] Upgrade to Mesos 0.21 to work around MESOS-1688 + Jongyoul Lee + 2015-01-09 10:47:08 -0800 + Commit: 454fe12, github.com/apache/spark/pull/3934 + + [SPARK-5145][Mllib] Add BLAS.dsyr and use it in GaussianMixtureEM + Liang-Chi Hsieh + 2015-01-09 10:27:33 -0800 + Commit: e9ca16e, github.com/apache/spark/pull/3949 + + [SPARK-1143] Separate pool tests into their own suite. + Kay Ousterhout + 2015-01-09 09:47:06 -0800 + Commit: b6aa557, github.com/apache/spark/pull/3967 + + HOTFIX: Minor improvements to make-distribution.sh + Patrick Wendell + 2015-01-09 09:40:18 -0800 + Commit: 1790b38, github.com/apache/spark/pull/3973 + + SPARK-5136 [DOCS] Improve documentation around setting up Spark IntelliJ project + Sean Owen + 2015-01-09 09:35:46 -0800 + Commit: 547df97, github.com/apache/spark/pull/3952 + + [Minor] Fix test RetryingBlockFetcherSuite after changed config name + Aaron Davidson + 2015-01-09 09:20:16 -0800 + Commit: b4034c3, github.com/apache/spark/pull/3972 + + [SPARK-5169][YARN]fetch the correct max attempts + WangTaoTheTonic + 2015-01-09 08:10:09 -0600 + Commit: f3da4bd, github.com/apache/spark/pull/3942 + + [SPARK-5122] Remove Shark from spark-ec2 + Nicholas Chammas + 2015-01-08 17:42:08 -0800 + Commit: 167a5ab, github.com/apache/spark/pull/3939 + + [SPARK-4048] Enhance and extend hadoop-provided profile. + Marcelo Vanzin + 2015-01-08 17:15:13 -0800 + Commit: 48cecf6, github.com/apache/spark/pull/2982 + + [SPARK-4891][PySpark][MLlib] Add gamma/log normal/exp dist sampling to P... + RJ Nowling + 2015-01-08 15:03:43 -0800 + Commit: c9c8b21, github.com/apache/spark/pull/3955 + + [SPARK-4973][CORE] Local directory in the driver of client-mode continues remaining even if application finished when external shuffle is enabled + Kousuke Saruta + 2015-01-08 13:43:09 -0800 + Commit: a00af6b, github.com/apache/spark/pull/3811 + + SPARK-5148 [MLlib] Make usersOut/productsOut storagelevel in ALS configurable + Fernando Otero (ZeoS) + 2015-01-08 12:42:54 -0800 + Commit: 72df5a3, github.com/apache/spark/pull/3953 + + Document that groupByKey will OOM for large keys + Eric Moyer + 2015-01-08 11:55:23 -0800 + Commit: 538f221, github.com/apache/spark/pull/3936 + + [SPARK-5130][Deploy]Take yarn-cluster as cluster mode in spark-submit + WangTaoTheTonic + 2015-01-08 11:45:42 -0800 + Commit: 0760787, github.com/apache/spark/pull/3929 + + [Minor] Fix the value represented by spark.executor.id for consistency. + Kousuke Saruta + 2015-01-08 11:35:56 -0800 + Commit: 0a59727, github.com/apache/spark/pull/3812 + + [SPARK-4989][CORE] avoid wrong eventlog conf cause cluster down in standalone mode + Zhang, Liye + 2015-01-08 10:40:26 -0800 + Commit: 06dc4b5, github.com/apache/spark/pull/3824 + + [SPARK-4917] Add a function to convert into a graph with canonical edges in GraphOps + Takeshi Yamamuro + 2015-01-08 09:55:12 -0800 + Commit: f825e19, github.com/apache/spark/pull/3760 + + SPARK-5087. [YARN] Merge yarn.Client and yarn.ClientBase + Sandy Ryza + 2015-01-08 09:25:43 -0800 + Commit: 8d45834, github.com/apache/spark/pull/3896 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2015-01-07 23:25:56 -0800 + Commit: c082385, github.com/apache/spark/pull/3880 + + [SPARK-5116][MLlib] Add extractor for SparseVector and DenseVector + Shuo Xiang + 2015-01-07 23:22:37 -0800 + Commit: c66a976, github.com/apache/spark/pull/3919 + + [SPARK-5126][Core] Verify Spark urls before creating Actors so that invalid urls can crash the process. + zsxwing + 2015-01-07 23:01:30 -0800 + Commit: 2b729d2, github.com/apache/spark/pull/3927 + + [SPARK-5132][Core]Correct stage Attempt Id key in stageInfofromJson + hushan[胡珊] + 2015-01-07 12:09:12 -0800 + Commit: d345ebe, github.com/apache/spark/pull/3932 + + [SPARK-5128][MLLib] Add common used log1pExp API in MLUtils + DB Tsai + 2015-01-07 10:13:41 -0800 + Commit: 60e2d9e, github.com/apache/spark/pull/3915 + + [SPARK-2458] Make failed application log visible on History Server + Masayoshi TSUZUKI + 2015-01-07 07:32:16 -0800 + Commit: 6e74ede, github.com/apache/spark/pull/3467 + + [SPARK-2165][YARN]add support for setting maxAppAttempts in the ApplicationSubmissionContext + WangTaoTheTonic + 2015-01-07 08:14:39 -0600 + Commit: 8fdd489, github.com/apache/spark/pull/3878 + + [YARN][SPARK-4929] Bug fix: fix the yarn-client code to support HA + huangzhaowei + 2015-01-07 08:10:42 -0600 + Commit: 5fde661, github.com/apache/spark/pull/3771 + + [SPARK-5099][Mllib] Simplify logistic loss function + Liang-Chi Hsieh + 2015-01-06 21:23:31 -0800 + Commit: e21acc1, github.com/apache/spark/pull/3899 + + [SPARK-5050][Mllib] Add unit test for sqdist + Liang-Chi Hsieh + 2015-01-06 14:00:45 -0800 + Commit: bb38ebb, github.com/apache/spark/pull/3869 + + SPARK-5017 [MLlib] - Use SVD to compute determinant and inverse of covariance matrix + Travis Galoppo + 2015-01-06 13:57:42 -0800 + Commit: 4108e5f, github.com/apache/spark/pull/3871 + + SPARK-4159 [CORE] Maven build doesn't run JUnit test suites + Sean Owen + 2015-01-06 12:02:08 -0800 + Commit: 4cba6eb, github.com/apache/spark/pull/3651 + + [Minor] Fix comments for GraphX 2D partitioning strategy + kj-ki + 2015-01-06 09:49:37 -0800 + Commit: 5e3ec11, github.com/apache/spark/pull/3904 + + [SPARK-1600] Refactor FileInputStream tests to remove Thread.sleep() calls and SystemClock usage + Josh Rosen + 2015-01-06 00:31:19 -0800 + Commit: a6394bc, github.com/apache/spark/pull/3801 + + SPARK-4843 [YARN] Squash ExecutorRunnableUtil and ExecutorRunnable + Kostas Sakellis + 2015-01-05 23:26:33 -0800 + Commit: 451546a, github.com/apache/spark/pull/3696 + + [SPARK-5040][SQL] Support expressing unresolved attributes using $"attribute name" notation in SQL DSL. + Reynold Xin + 2015-01-05 15:34:22 -0800 + Commit: 04d55d8, github.com/apache/spark/pull/3862 + + [SPARK-5093] Set spark.network.timeout to 120s consistently. + Reynold Xin + 2015-01-05 15:19:53 -0800 + Commit: bbcba3a, github.com/apache/spark/pull/3903 + + [SPARK-5089][PYSPARK][MLLIB] Fix vector convert + freeman + 2015-01-05 13:10:59 -0800 + Commit: 6c6f325, github.com/apache/spark/pull/3902 + + [SPARK-4465] runAsSparkUser doesn't affect TaskRunner in Mesos environme... + Jongyoul Lee + 2015-01-05 12:05:09 -0800 + Commit: 1c0e7ce, github.com/apache/spark/pull/3741 + + [SPARK-5057] Log message in failed askWithReply attempts + WangTao , WangTaoTheTonic + 2015-01-05 11:59:38 -0800 + Commit: ce39b34, github.com/apache/spark/pull/3875 + + [SPARK-4688] Have a single shared network timeout in Spark + Varun Saxena , varunsaxena + 2015-01-05 10:32:37 -0800 + Commit: d3f07fd, github.com/apache/spark/pull/3562 + + [SPARK-5074][Core] Fix a non-deterministic test failure + zsxwing + 2015-01-04 21:18:33 -0800 + Commit: 5c506ce, github.com/apache/spark/pull/3889 + + [SPARK-5083][Core] Fix a flaky test in TaskResultGetterSuite + zsxwing + 2015-01-04 21:09:21 -0800 + Commit: 27e7f5a, github.com/apache/spark/pull/3894 + + [SPARK-5069][Core] Fix the race condition of TaskSchedulerImpl.dagScheduler + zsxwing + 2015-01-04 21:06:04 -0800 + Commit: 6c726a3, github.com/apache/spark/pull/3887 + + [SPARK-5067][Core] Use '===' to compare well-defined case class + zsxwing + 2015-01-04 21:03:17 -0800 + Commit: 7239652, github.com/apache/spark/pull/3886 + + [SPARK-4835] Disable validateOutputSpecs for Spark Streaming jobs + Josh Rosen + 2015-01-04 20:26:18 -0800 + Commit: 939ba1f, github.com/apache/spark/pull/3832 + + [SPARK-4631] unit test for MQTT + bilna , Bilna P + 2015-01-04 19:37:48 -0800 + Commit: e767d7d, github.com/apache/spark/pull/3844 + + [SPARK-4787] Stop SparkContext if a DAGScheduler init error occurs + Dale + 2015-01-04 13:28:37 -0800 + Commit: 3fddc94, github.com/apache/spark/pull/3809 + + [SPARK-794][Core] Remove sleep() in ClusterScheduler.stop + Brennon York + 2015-01-04 12:40:39 -0800 + Commit: b96008d, github.com/apache/spark/pull/3851 + + [SPARK-5058] Updated broken links + sigmoidanalytics + 2015-01-03 19:46:08 -0800 + Commit: 342612b, github.com/apache/spark/pull/3877 + + Fixed typos in streaming-kafka-integration.md + Akhil Das + 2015-01-02 15:12:27 -0800 + Commit: cdccc26, github.com/apache/spark/pull/3876 + + [SPARK-3325][Streaming] Add a parameter to the method print in class DStream + Yadong Qi , q00251598 , Tathagata Das , wangfei + 2015-01-02 15:09:41 -0800 + Commit: bd88b71, github.com/apache/spark/pull/3865 + + [HOTFIX] Bind web UI to ephemeral port in DriverSuite + Josh Rosen + 2015-01-01 15:03:54 -0800 + Commit: 0128398, github.com/apache/spark/pull/3873 + + [SPARK-5038] Add explicit return type for implicit functions. + Reynold Xin + 2014-12-31 17:07:47 -0800 + Commit: 7749dd6, github.com/apache/spark/pull/3860 + + SPARK-2757 [BUILD] [STREAMING] Add Mima test for Spark Sink after 1.10 is released + Sean Owen + 2014-12-31 16:59:17 -0800 + Commit: 4bb1248, github.com/apache/spark/pull/3842 + + [SPARK-5035] [Streaming] ReceiverMessage trait should extend Serializable + Josh Rosen + 2014-12-31 16:02:47 -0800 + Commit: fe6efac, github.com/apache/spark/pull/3857 + + SPARK-5020 [MLlib] GaussianMixtureModel.predictMembership() should take an RDD only + Travis Galoppo + 2014-12-31 15:39:58 -0800 + Commit: c4f0b4f, github.com/apache/spark/pull/3854 + + [SPARK-5028][Streaming]Add total received and processed records metrics to Streaming UI + jerryshao + 2014-12-31 14:45:31 -0800 + Commit: fdc2aa4, github.com/apache/spark/pull/3852 + + [SPARK-4790][STREAMING] Fix ReceivedBlockTrackerSuite waits for old file... + Hari Shreedharan + 2014-12-31 14:35:07 -0800 + Commit: 3610d3c, github.com/apache/spark/pull/3726 + + [SPARK-5038][SQL] Add explicit return type for implicit functions in Spark SQL + Reynold Xin + 2014-12-31 14:25:03 -0800 + Commit: c88a3d7, github.com/apache/spark/pull/3859 + + [HOTFIX] Disable Spark UI in SparkSubmitSuite tests + Josh Rosen + 2014-12-12 12:38:37 -0800 + Commit: e24d3a9 + + SPARK-4547 [MLLIB] OOM when making bins in BinaryClassificationMetrics + Sean Owen + 2014-12-31 13:37:04 -0800 + Commit: 3d194cc, github.com/apache/spark/pull/3702 + + [SPARK-4298][Core] - The spark-submit cannot read Main-Class from Manifest. + Brennon York + 2014-12-31 11:54:10 -0800 + Commit: 8e14c5e, github.com/apache/spark/pull/3561 + + [SPARK-4797] Replace breezeSquaredDistance + Liang-Chi Hsieh + 2014-12-31 11:50:53 -0800 + Commit: 06a9aa5, github.com/apache/spark/pull/3643 + + [SPARK-1010] Clean up uses of System.setProperty in unit tests + Josh Rosen + 2014-12-30 18:12:20 -0800 + Commit: 352ed6b, github.com/apache/spark/pull/3739 + + [SPARK-4998][MLlib]delete the "train" function + Liu Jiongzhou + 2014-12-30 15:55:56 -0800 + Commit: 035bac8, github.com/apache/spark/pull/3836 + + [SPARK-4813][Streaming] Fix the issue that ContextWaiter didn't handle 'spurious wakeup' + zsxwing + 2014-12-30 14:39:13 -0800 + Commit: 6a89782, github.com/apache/spark/pull/3661 + + [Spark-4995] Replace Vector.toBreeze.activeIterator with foreachActive + Jakub Dubovsky + 2014-12-30 14:19:07 -0800 + Commit: 0f31992, github.com/apache/spark/pull/3846 + + SPARK-3955 part 2 [CORE] [HOTFIX] Different versions between jackson-mapper-asl and jackson-core-asl + Sean Owen + 2014-12-30 14:00:57 -0800 + Commit: b239ea1, github.com/apache/spark/pull/3829 + + [SPARK-4570][SQL]add BroadcastLeftSemiJoinHash + wangxiaojing + 2014-12-30 13:54:12 -0800 + Commit: 07fa191, github.com/apache/spark/pull/3442 + + [SPARK-4935][SQL] When hive.cli.print.header configured, spark-sql aborted if passed in a invalid sql + wangfei , Fei Wang + 2014-12-30 13:44:30 -0800 + Commit: 8f29b7c, github.com/apache/spark/pull/3761 + + [SPARK-4386] Improve performance when writing Parquet files + Michael Davies + 2014-12-30 13:40:51 -0800 + Commit: 7425bec, github.com/apache/spark/pull/3843 + + [SPARK-4937][SQL] Normalizes conjunctions and disjunctions to eliminate common predicates + Cheng Lian + 2014-12-30 13:38:27 -0800 + Commit: 61a99f6, github.com/apache/spark/pull/3784 + + [SPARK-4928][SQL] Fix: Operator '>,<,>=,<=' with decimal between different precision report error + guowei2 + 2014-12-30 12:21:00 -0800 + Commit: a75dd83, github.com/apache/spark/pull/3767 + + [SPARK-4930][SQL][DOCS]Update SQL programming guide, CACHE TABLE is eager + luogankun + 2014-12-30 12:18:55 -0800 + Commit: 2deac74, github.com/apache/spark/pull/3773 + + [SPARK-4916][SQL][DOCS]Update SQL programming guide about cache section + luogankun + 2014-12-30 12:17:49 -0800 + Commit: f7a41a0, github.com/apache/spark/pull/3759 + + [SPARK-4493][SQL] Tests for IsNull / IsNotNull in the ParquetFilterSuite + Cheng Lian + 2014-12-30 12:16:45 -0800 + Commit: 19a8802, github.com/apache/spark/pull/3748 + + [Spark-4512] [SQL] Unresolved Attribute Exception in Sort By + Cheng Hao + 2014-12-30 12:11:44 -0800 + Commit: 53f0a00, github.com/apache/spark/pull/3386 + + [SPARK-5002][SQL] Using ascending by default when not specify order in order by + wangfei + 2014-12-30 12:07:24 -0800 + Commit: daac221, github.com/apache/spark/pull/3838 + + [SPARK-4904] [SQL] Remove the unnecessary code change in Generic UDF + Cheng Hao + 2014-12-30 11:47:08 -0800 + Commit: 63b84b7, github.com/apache/spark/pull/3745 + + [SPARK-4959] [SQL] Attributes are case sensitive when using a select query from a projection + Cheng Hao + 2014-12-30 11:33:47 -0800 + Commit: 5595eaa, github.com/apache/spark/pull/3796 + + [SPARK-4975][SQL] Fix HiveInspectorSuite test failure + scwf , Fei Wang + 2014-12-30 11:30:47 -0800 + Commit: 65357f1, github.com/apache/spark/pull/3814 + + [SQL] enable view test + Daoyuan Wang + 2014-12-30 11:29:13 -0800 + Commit: 94d60b7, github.com/apache/spark/pull/3826 + + [SPARK-4908][SQL] Prevent multiple concurrent hive native commands + Michael Armbrust + 2014-12-30 11:24:46 -0800 + Commit: 480bd1d, github.com/apache/spark/pull/3834 + + [SPARK-4882] Register PythonBroadcast with Kryo so that PySpark works with KryoSerializer + Josh Rosen + 2014-12-30 09:29:52 -0800 + Commit: efa80a53, github.com/apache/spark/pull/3831 + + [SPARK-4920][UI] add version on master and worker page for standalone mode + Zhang, Liye + 2014-12-30 09:19:47 -0800 + Commit: 9077e72, github.com/apache/spark/pull/3769 + + [SPARK-4972][MLlib] Updated the scala doc for lasso and ridge regression for the change of LeastSquaresGradient + DB Tsai + 2014-12-29 17:17:12 -0800 + Commit: 040d6f2, github.com/apache/spark/pull/3808 + + Added setMinCount to Word2Vec.scala + ganonp + 2014-12-29 15:31:19 -0800 + Commit: 343db39, github.com/apache/spark/pull/3693 + + SPARK-4156 [MLLIB] EM algorithm for GMMs + Travis Galoppo , Travis Galoppo , tgaloppo , FlytxtRnD + 2014-12-29 15:29:15 -0800 + Commit: 6cf6fdf, github.com/apache/spark/pull/3022 + + SPARK-4968: takeOrdered to skip reduce step in case mappers return no partitions + Yash Datta + 2014-12-29 13:49:45 -0800 + Commit: 9bc0df6, github.com/apache/spark/pull/3830 + + [SPARK-4409][MLlib] Additional Linear Algebra Utils + Burak Yavuz , Xiangrui Meng + 2014-12-29 13:24:26 -0800 + Commit: 02b55de, github.com/apache/spark/pull/3319 + + [Minor] Fix a typo of type parameter in JavaUtils.scala + Kousuke Saruta + 2014-12-29 12:05:08 -0800 + Commit: 8d72341, github.com/apache/spark/pull/3789 + + [SPARK-4946] [CORE] Using AkkaUtils.askWithReply in MapOutputTracker.askTracker to reduce the chance of the communicating problem + YanTangZhai , yantangzhai + 2014-12-29 11:30:54 -0800 + Commit: 815de54, github.com/apache/spark/pull/3785 + + Adde LICENSE Header to build/mvn, build/sbt and sbt/sbt + Kousuke Saruta + 2014-12-29 10:48:53 -0800 + Commit: 4cef05e, github.com/apache/spark/pull/3817 + + [SPARK-4982][DOC] `spark.ui.retainedJobs` description is wrong in Spark UI configuration guide + wangxiaojing + 2014-12-29 10:45:14 -0800 + Commit: 6645e52, github.com/apache/spark/pull/3818 + + [SPARK-4966][YARN]The MemoryOverhead value is setted not correctly + meiyoula <1039320815@qq.com> + 2014-12-29 08:20:30 -0600 + Commit: 14fa87b, github.com/apache/spark/pull/3797 + + [SPARK-4501][Core] - Create build/mvn to automatically download maven/zinc/scalac + Brennon York + 2014-12-27 13:25:18 -0800 + Commit: a3e51cc, github.com/apache/spark/pull/3707 + + [SPARK-4952][Core]Handle ConcurrentModificationExceptions in SparkEnv.environmentDetails + GuoQiang Li + 2014-12-26 23:31:29 -0800 + Commit: 080ceb7, github.com/apache/spark/pull/3788 + + [SPARK-4954][Core] add spark version infomation in log for standalone mode + Zhang, Liye + 2014-12-26 23:23:13 -0800 + Commit: 786808a, github.com/apache/spark/pull/3790 + + [SPARK-3955] Different versions between jackson-mapper-asl and jackson-c... + Jongyoul Lee + 2014-12-26 22:59:34 -0800 + Commit: 2483c1e, github.com/apache/spark/pull/3716 + + HOTFIX: Slight tweak on previous commit. + Patrick Wendell + 2014-12-26 22:55:04 -0800 + Commit: 82bf4be + + [SPARK-3787][BUILD] Assembly jar name is wrong when we build with sbt omitting -Dhadoop.version + Kousuke Saruta + 2014-12-26 22:52:04 -0800 + Commit: de95c57, github.com/apache/spark/pull/3046 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-12-26 22:39:56 -0800 + Commit: 534f24b, github.com/apache/spark/pull/3456 + + SPARK-4971: Fix typo in BlockGenerator comment + CodingCat + 2014-12-26 12:03:22 -0800 + Commit: fda4331, github.com/apache/spark/pull/3807 + + [SPARK-4608][Streaming] Reorganize StreamingContext implicit to improve API convenience + zsxwing + 2014-12-25 19:46:05 -0800 + Commit: f9ed2b6, github.com/apache/spark/pull/3464 + + [SPARK-4537][Streaming] Expand StreamingSource to add more metrics + jerryshao + 2014-12-25 19:39:49 -0800 + Commit: f205fe4, github.com/apache/spark/pull/3466 + + [EC2] Update mesos/spark-ec2 branch to branch-1.3 + Nicholas Chammas + 2014-12-25 14:16:50 -0800 + Commit: ac82785, github.com/apache/spark/pull/3804 + + [EC2] Update default Spark version to 1.2.0 + Nicholas Chammas + 2014-12-25 14:13:12 -0800 + Commit: b6b6393, github.com/apache/spark/pull/3793 + + Fix "Building Spark With Maven" link in README.md + Denny Lee + 2014-12-25 14:05:55 -0800 + Commit: 08b18c7, github.com/apache/spark/pull/3802 + + [SPARK-4953][Doc] Fix the description of building Spark with YARN + Kousuke Saruta + 2014-12-25 07:05:43 -0800 + Commit: 11dd993, github.com/apache/spark/pull/3787 + + [SPARK-4873][Streaming] Use `Future.zip` instead of `Future.flatMap`(for-loop) in WriteAheadLogBasedBlockHandler + zsxwing + 2014-12-24 19:49:41 -0800 + Commit: b4d0db8, github.com/apache/spark/pull/3721 + + SPARK-4297 [BUILD] Build warning fixes omnibus + Sean Owen + 2014-12-24 13:32:51 -0800 + Commit: 29fabb1, github.com/apache/spark/pull/3157 + + [SPARK-4881][Minor] Use SparkConf#getBoolean instead of get().toBoolean + Kousuke Saruta + 2014-12-23 19:14:34 -0800 + Commit: 199e59a, github.com/apache/spark/pull/3733 + + [SPARK-4860][pyspark][sql] speeding up `sample()` and `takeSample()` + jbencook , J. Benjamin Cook + 2014-12-23 17:46:24 -0800 + Commit: fd41eb9, github.com/apache/spark/pull/3764 + + [SPARK-4606] Send EOF to child JVM when there's no more data to read. + Marcelo Vanzin + 2014-12-23 16:02:59 -0800 + Commit: 7e2deb7, github.com/apache/spark/pull/3460 + + [SPARK-4671][Streaming]Do not replicate streaming block when WAL is enabled + jerryshao + 2014-12-23 15:45:53 -0800 + Commit: 3f5f4cc, github.com/apache/spark/pull/3534 + + [SPARK-4802] [streaming] Remove receiverInfo once receiver is de-registered + Ilayaperumal Gopinathan + 2014-12-23 15:14:54 -0800 + Commit: 10d69e9, github.com/apache/spark/pull/3647 + + [SPARK-4913] Fix incorrect event log path + Liang-Chi Hsieh + 2014-12-23 14:58:33 -0800 + Commit: 96281cd, github.com/apache/spark/pull/3755 + + [SPARK-4730][YARN] Warn against deprecated YARN settings + Andrew Or + 2014-12-23 14:28:36 -0800 + Commit: 27c5399, github.com/apache/spark/pull/3590 + + [SPARK-4914][Build] Cleans lib_managed before compiling with Hive 0.13.1 + Cheng Lian + 2014-12-23 12:54:20 -0800 + Commit: 395b771, github.com/apache/spark/pull/3756 + + [SPARK-4932] Add help comments in Analytics + Takeshi Yamamuro + 2014-12-23 12:39:41 -0800 + Commit: 9c251c5, github.com/apache/spark/pull/3775 + + [SPARK-4834] [standalone] Clean up application files after app finishes. + Marcelo Vanzin + 2014-12-23 12:02:08 -0800 + Commit: dd15536, github.com/apache/spark/pull/3705 + + [SPARK-4931][Yarn][Docs] Fix the format of running-on-yarn.md + zsxwing + 2014-12-23 11:18:06 -0800 + Commit: 2d215ae, github.com/apache/spark/pull/3774 + + [SPARK-4890] Ignore downloaded EC2 libs + Nicholas Chammas + 2014-12-23 11:12:16 -0800 + Commit: 2823c7f, github.com/apache/spark/pull/3770 + + [Docs] Minor typo fixes + Nicholas Chammas + 2014-12-22 22:54:32 -0800 + Commit: 0e532cc, github.com/apache/spark/pull/3772 + + [SPARK-4907][MLlib] Inconsistent loss and gradient in LeastSquaresGradient compared with R + DB Tsai + 2014-12-22 16:42:55 -0800 + Commit: a96b727, github.com/apache/spark/pull/3746 + + [SPARK-4818][Core] Add 'iterator' to reduce memory consumed by join + zsxwing + 2014-12-22 14:26:28 -0800 + Commit: c233ab3, github.com/apache/spark/pull/3671 + + [SPARK-4920][UI]:current spark version in UI is not striking. + genmao.ygm + 2014-12-22 14:14:39 -0800 + Commit: de9d7d2, github.com/apache/spark/pull/3763 + + [Minor] Fix scala doc + Liang-Chi Hsieh + 2014-12-22 14:13:31 -0800 + Commit: a61aa66, github.com/apache/spark/pull/3751 + + [SPARK-4864] Add documentation to Netty-based configs + Aaron Davidson + 2014-12-22 13:09:22 -0800 + Commit: fbca6b6, github.com/apache/spark/pull/3713 + + [SPARK-4079] [CORE] Consolidates Errors if a CompressionCodec is not available + Kostas Sakellis + 2014-12-22 13:07:01 -0800 + Commit: 7c0ed13, github.com/apache/spark/pull/3119 + + SPARK-4447. Remove layers of abstraction in YARN code no longer needed after dropping yarn-alpha + Sandy Ryza + 2014-12-22 12:23:43 -0800 + Commit: d62da64, github.com/apache/spark/pull/3652 + + [SPARK-4733] Add missing prameter comments in ShuffleDependency + Takeshi Yamamuro + 2014-12-22 12:19:23 -0800 + Commit: fb8e85e, github.com/apache/spark/pull/3594 + + [Minor] Improve some code in BroadcastTest for short + carlmartin + 2014-12-22 12:13:53 -0800 + Commit: 1d9788e, github.com/apache/spark/pull/3750 + + [SPARK-4883][Shuffle] Add a name to the directoryCleaner thread + zsxwing + 2014-12-22 12:11:36 -0800 + Commit: 8773705, github.com/apache/spark/pull/3734 + + [SPARK-4870] Add spark version to driver log + Zhang, Liye + 2014-12-22 11:36:49 -0800 + Commit: 39272c8, github.com/apache/spark/pull/3717 + + [SPARK-4915][YARN] Fix classname to be specified for external shuffle service. + Tsuyoshi Ozawa + 2014-12-22 11:28:05 -0800 + Commit: 96606f6, github.com/apache/spark/pull/3757 + + [SPARK-4918][Core] Reuse Text in saveAsTextFile + zsxwing + 2014-12-22 11:20:00 -0800 + Commit: 93b2f3a, github.com/apache/spark/pull/3762 + + [SPARK-2075][Core] Make the compiler generate same bytes code for Hadoop 1.+ and Hadoop 2.+ + zsxwing + 2014-12-21 22:10:19 -0800 + Commit: 6ee6aa7, github.com/apache/spark/pull/3740 + + SPARK-4910 [CORE] build failed (use of FileStatus.isFile in Hadoop 1.x) + Sean Owen + 2014-12-21 13:16:57 -0800 + Commit: c6a3c0d, github.com/apache/spark/pull/3754 + + [Minor] Build Failed: value defaultProperties not found + huangzhaowei + 2014-12-19 23:32:56 -0800 + Commit: a764960, github.com/apache/spark/pull/3749 + + [SPARK-4140] Document dynamic allocation + Andrew Or , Tsuyoshi Ozawa + 2014-12-19 19:36:20 -0800 + Commit: 15c03e1, github.com/apache/spark/pull/3731 + + [SPARK-4831] Do not include SPARK_CLASSPATH if empty + Daniel Darabos + 2014-12-19 19:32:39 -0800 + Commit: 7cb3f54, github.com/apache/spark/pull/3678 + + SPARK-2641: Passing num executors to spark arguments from properties file + Kanwaljit Singh + 2014-12-19 19:25:39 -0800 + Commit: 1d64812, github.com/apache/spark/pull/1657 + + [SPARK-3060] spark-shell.cmd doesn't accept application options in Windows OS + Masayoshi TSUZUKI + 2014-12-19 19:19:53 -0800 + Commit: 8d93247, github.com/apache/spark/pull/3350 + + change signature of example to match released code + Eran Medan + 2014-12-19 18:29:36 -0800 + Commit: c25c669, github.com/apache/spark/pull/3747 + + [SPARK-2261] Make event logger use a single file. + Marcelo Vanzin + 2014-12-19 18:21:15 -0800 + Commit: 4564519, github.com/apache/spark/pull/1222 + + [SPARK-4890] Upgrade Boto to 2.34.0; automatically download Boto from PyPi instead of packaging it + Josh Rosen + 2014-12-19 17:02:37 -0800 + Commit: c28083f, github.com/apache/spark/pull/3737 + + [SPARK-4896] don’t redundantly overwrite executor JAR deps + Ryan Williams + 2014-12-19 15:24:41 -0800 + Commit: 7981f96, github.com/apache/spark/pull/2848 + + [SPARK-4889] update history server example cmds + Ryan Williams + 2014-12-19 13:56:04 -0800 + Commit: cdb2c64, github.com/apache/spark/pull/3736 + + Small refactoring to pass SparkEnv into Executor rather than creating SparkEnv in Executor. + Reynold Xin + 2014-12-19 12:51:12 -0800 + Commit: 336cd34, github.com/apache/spark/pull/3738 + + [Build] Remove spark-staging-1038 + scwf + 2014-12-19 08:29:38 -0800 + Commit: 8e253eb, github.com/apache/spark/pull/3743 + + [SPARK-4901] [SQL] Hot fix for ByteWritables.copyBytes + Cheng Hao + 2014-12-19 08:04:41 -0800 + Commit: 5479450, github.com/apache/spark/pull/3742 + + SPARK-3428. TaskMetrics for running tasks is missing GC time metrics + Sandy Ryza + 2014-12-18 22:40:44 -0800 + Commit: 283263f, github.com/apache/spark/pull/3684 + + [SPARK-4674] Refactor getCallSite + Liang-Chi Hsieh + 2014-12-18 21:41:02 -0800 + Commit: d7fc69a, github.com/apache/spark/pull/3532 + + [SPARK-4728][MLLib] Add exponential, gamma, and log normal sampling to MLlib da... + RJ Nowling + 2014-12-18 21:00:49 -0800 + Commit: ee1fb97, github.com/apache/spark/pull/3680 + + [SPARK-4861][SQL] Refactory command in spark sql + wangfei , scwf + 2014-12-18 20:24:56 -0800 + Commit: c3d91da, github.com/apache/spark/pull/3712 + + [SPARK-4573] [SQL] Add SettableStructObjectInspector support in "wrap" function + Cheng Hao + 2014-12-18 20:21:52 -0800 + Commit: ae9f128, github.com/apache/spark/pull/3429 + + [SPARK-2554][SQL] Supporting SumDistinct partial aggregation + ravipesala + 2014-12-18 20:19:10 -0800 + Commit: 7687415, github.com/apache/spark/pull/3348 + + [SPARK-4693] [SQL] PruningPredicates may be wrong if predicates contains an empty AttributeSet() references + YanTangZhai , yantangzhai + 2014-12-18 20:13:46 -0800 + Commit: e7de7e5, github.com/apache/spark/pull/3556 + + [SPARK-4756][SQL] FIX: sessionToActivePool grow infinitely, even as sessions expire + guowei2 + 2014-12-18 20:10:23 -0800 + Commit: 22ddb6e, github.com/apache/spark/pull/3617 + + [SPARK-3928][SQL] Support wildcard matches on Parquet files. + Thu Kyaw + 2014-12-18 20:08:32 -0800 + Commit: b68bc6d, github.com/apache/spark/pull/3407 + + [SPARK-2663] [SQL] Support the Grouping Set + Cheng Hao + 2014-12-18 18:58:29 -0800 + Commit: f728e0f, github.com/apache/spark/pull/1567 + + [SPARK-4754] Refactor SparkContext into ExecutorAllocationClient + Andrew Or + 2014-12-18 17:37:42 -0800 + Commit: 9804a75, github.com/apache/spark/pull/3614 + + [SPARK-4837] NettyBlockTransferService should use spark.blockManager.port config + Aaron Davidson + 2014-12-18 16:43:16 -0800 + Commit: 105293a, github.com/apache/spark/pull/3688 + + SPARK-4743 - Use SparkEnv.serializer instead of closureSerializer in aggregateByKey and foldByKey + Ivan Vergiliev + 2014-12-18 16:29:36 -0800 + Commit: f9f58b9, github.com/apache/spark/pull/3605 + + [SPARK-4884]: Improve Partition docs + Madhu Siddalingaiah + 2014-12-18 16:00:53 -0800 + Commit: d5a596d, github.com/apache/spark/pull/3722 + + [SPARK-4880] remove spark.locality.wait in Analytics + Ernest + 2014-12-18 15:42:26 -0800 + Commit: a7ed6f3, github.com/apache/spark/pull/3730 + + [SPARK-4887][MLlib] Fix a bad unittest in LogisticRegressionSuite + DB Tsai + 2014-12-18 13:55:49 -0800 + Commit: 59a49db, github.com/apache/spark/pull/3735 + + [SPARK-3607] ConnectionManager threads.max configs on the thread pools don't work + Ilya Ganelin + 2014-12-18 12:53:18 -0800 + Commit: 3720057, github.com/apache/spark/pull/3664 + + Add mesos specific configurations into doc + Timothy Chen + 2014-12-18 12:15:53 -0800 + Commit: d9956f8, github.com/apache/spark/pull/3349 + + SPARK-3779. yarn spark.yarn.applicationMaster.waitTries config should be... + Sandy Ryza + 2014-12-18 12:19:07 -0600 + Commit: 253b72b, github.com/apache/spark/pull/3471 + + [SPARK-4461][YARN] pass extra java options to yarn application master + Zhan Zhang + 2014-12-18 10:01:46 -0600 + Commit: 3b76469, github.com/apache/spark/pull/3409 + + [SPARK-4822] Use sphinx tags for Python doc annotations + lewuathe + 2014-12-17 17:31:24 -0800 + Commit: 3cd5161, github.com/apache/spark/pull/3685 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-12-17 15:50:10 -0800 + Commit: ca12608, github.com/apache/spark/pull/3137 + + [SPARK-3891][SQL] Add array support to percentile, percentile_approx and constant inspectors support + Venkata Ramana G , Venkata Ramana Gollamudi + 2014-12-17 15:41:35 -0800 + Commit: f33d550, github.com/apache/spark/pull/2802 + + [SPARK-4856] [SQL] NullType instead of StringType when sampling against empty string or nul... + Cheng Hao + 2014-12-17 15:01:59 -0800 + Commit: 8d0d2a6, github.com/apache/spark/pull/3708 + + [HOTFIX][SQL] Fix parquet filter suite + Michael Armbrust + 2014-12-17 14:27:02 -0800 + Commit: 19c0faa, github.com/apache/spark/pull/3727 + + [SPARK-4821] [mllib] [python] [docs] Fix for pyspark.mllib.rand doc + Joseph K. Bradley + 2014-12-17 14:12:46 -0800 + Commit: affc3f4, github.com/apache/spark/pull/3669 + + [SPARK-3739] [SQL] Update the split num base on block size for table scanning + Cheng Hao + 2014-12-17 13:39:36 -0800 + Commit: 636d9fc, github.com/apache/spark/pull/2589 + + [SPARK-4755] [SQL] sqrt(negative value) should return null + Daoyuan Wang + 2014-12-17 12:51:27 -0800 + Commit: 902e4d5, github.com/apache/spark/pull/3616 + + [SPARK-4493][SQL] Don't pushdown Eq, NotEq, Lt, LtEq, Gt and GtEq predicates with nulls for Parquet + Cheng Lian + 2014-12-17 12:48:04 -0800 + Commit: 6277135, github.com/apache/spark/pull/3367 + + [SPARK-3698][SQL] Fix case insensitive resolution of GetField. + Michael Armbrust + 2014-12-17 12:43:51 -0800 + Commit: 7ad579e, github.com/apache/spark/pull/3724 + + [SPARK-4694]Fix HiveThriftServer2 cann't stop In Yarn HA mode. + carlmartin + 2014-12-17 12:24:03 -0800 + Commit: 4782def, github.com/apache/spark/pull/3576 + + [SPARK-4625] [SQL] Add sort by for DSL & SimpleSqlParser + Cheng Hao + 2014-12-17 12:01:57 -0800 + Commit: 5fdcbdc, github.com/apache/spark/pull/3481 + + [SPARK-4595][Core] Fix MetricsServlet not work issue + Saisai Shao , Josh Rosen , jerryshao + 2014-12-17 11:47:44 -0800 + Commit: cf50631, github.com/apache/spark/pull/3444 + + [HOTFIX] Fix RAT exclusion for known_translations file + Josh Rosen + 2014-12-16 23:00:25 -0800 + Commit: 3d0c37b, github.com/apache/spark/pull/3719 + + [Release] Update contributors list format and sort it + Andrew Or + 2014-12-16 22:11:03 -0800 + Commit: 4e1112e + + [SPARK-4618][SQL] Make foreign DDL commands options case-insensitive + scwf , wangfei + 2014-12-16 21:26:36 -0800 + Commit: 6069880, github.com/apache/spark/pull/3470 + + [SPARK-4866] support StructType as key in MapType + Davies Liu + 2014-12-16 21:23:28 -0800 + Commit: ec5c427, github.com/apache/spark/pull/3714 + + [SPARK-4375] [SQL] Add 0 argument support for udf + Cheng Hao + 2014-12-16 21:21:11 -0800 + Commit: 770d815, github.com/apache/spark/pull/3595 + + [SPARK-4720][SQL] Remainder should also return null if the divider is 0. + Takuya UESHIN + 2014-12-16 21:19:57 -0800 + Commit: ddc7ba3, github.com/apache/spark/pull/3581 + + [SPARK-4744] [SQL] Short circuit evaluation for AND & OR in CodeGen + Cheng Hao + 2014-12-16 21:18:39 -0800 + Commit: 0aa834a, github.com/apache/spark/pull/3606 + + [SPARK-4798][SQL] A new set of Parquet testing API and test suites + Cheng Lian + 2014-12-16 21:16:03 -0800 + Commit: 3b395e1, github.com/apache/spark/pull/3644 + + [Release] Cache known author translations locally + Andrew Or + 2014-12-16 19:28:43 -0800 + Commit: b85044e + + [Release] Major improvements to generate contributors script + Andrew Or + 2014-12-16 17:55:27 -0800 + Commit: 6f80b74 + + [SPARK-4269][SQL] make wait time configurable in BroadcastHashJoin + Jacky Li + 2014-12-16 15:34:59 -0800 + Commit: fa66ef6, github.com/apache/spark/pull/3133 + + [SPARK-4827][SQL] Fix resolution of deeply nested Project(attr, Project(Star,...)). + Michael Armbrust + 2014-12-16 15:31:19 -0800 + Commit: a66c23e, github.com/apache/spark/pull/3674 + + [SPARK-4483][SQL]Optimization about reduce memory costs during the HashOuterJoin + tianyi , tianyi + 2014-12-16 15:22:29 -0800 + Commit: 30f6b85, github.com/apache/spark/pull/3375 + + [SPARK-4527][SQl]Add BroadcastNestedLoopJoin operator selection testsuite + wangxiaojing + 2014-12-16 14:45:56 -0800 + Commit: ea1315e, github.com/apache/spark/pull/3395 + + SPARK-4767: Add support for launching in a specified placement group to spark_ec2 + Holden Karau + 2014-12-16 14:37:04 -0800 + Commit: b0dfdbd, github.com/apache/spark/pull/3623 + + [SPARK-4812][SQL] Fix the initialization issue of 'codegenEnabled' + zsxwing + 2014-12-16 14:13:40 -0800 + Commit: 6530243, github.com/apache/spark/pull/3660 + + [SPARK-4847][SQL]Fix "extraStrategies cannot take effect in SQLContext" issue + jerryshao + 2014-12-16 14:08:28 -0800 + Commit: dc8280d, github.com/apache/spark/pull/3698 + + [DOCS][SQL] Add a Note on jsonFile having separate JSON objects per line + Peter Vandenabeele + 2014-12-16 13:57:55 -0800 + Commit: 1a9e35e, github.com/apache/spark/pull/3517 + + [SQL] SPARK-4700: Add HTTP protocol spark thrift server + Judy Nash , judynash + 2014-12-16 12:37:26 -0800 + Commit: 17688d1, github.com/apache/spark/pull/3672 + + [SPARK-3405] add subnet-id and vpc-id options to spark_ec2.py + Mike Jennings , Mike Jennings + 2014-12-16 12:13:21 -0800 + Commit: d12c071, github.com/apache/spark/pull/2872 + + [SPARK-4855][mllib] testing the Chi-squared hypothesis test + jbencook + 2014-12-16 11:37:23 -0800 + Commit: cb48447, github.com/apache/spark/pull/3679 + + [SPARK-4437] update doc for WholeCombineFileRecordReader + Davies Liu , Josh Rosen + 2014-12-16 11:19:36 -0800 + Commit: ed36200, github.com/apache/spark/pull/3301 + + [SPARK-4841] fix zip with textFile() + Davies Liu + 2014-12-15 22:58:26 -0800 + Commit: c246b95, github.com/apache/spark/pull/3706 + + [SPARK-4792] Add error message when making local dir unsuccessfully + meiyoula <1039320815@qq.com> + 2014-12-15 22:30:18 -0800 + Commit: c762877, github.com/apache/spark/pull/3635 + + SPARK-4814 [CORE] Enable assertions in SBT, Maven tests / AssertionError from Hive's LazyBinaryInteger + Sean Owen + 2014-12-15 17:12:05 -0800 + Commit: 81112e4, github.com/apache/spark/pull/3692 + + [Minor][Core] fix comments in MapOutputTracker + wangfei + 2014-12-15 16:46:21 -0800 + Commit: 5c24759, github.com/apache/spark/pull/3700 + + SPARK-785 [CORE] ClosureCleaner not invoked on most PairRDDFunctions + Sean Owen + 2014-12-15 16:06:15 -0800 + Commit: 2a28bc6, github.com/apache/spark/pull/3690 + + [SPARK-4668] Fix some documentation typos. + Ryan Williams + 2014-12-15 14:52:17 -0800 + Commit: 8176b7a, github.com/apache/spark/pull/3523 + + [SPARK-1037] The name of findTaskFromList & findTask in TaskSetManager.scala is confusing + Ilya Ganelin + 2014-12-15 14:51:15 -0800 + Commit: 38703bb, github.com/apache/spark/pull/3665 + + [SPARK-4826] Fix generation of temp file names in WAL tests + Josh Rosen + 2014-12-15 14:33:43 -0800 + Commit: f6b8591, github.com/apache/spark/pull/3695. + + [SPARK-4494][mllib] IDFModel.transform() add support for single vector + Yuu ISHIKAWA + 2014-12-15 13:44:15 -0800 + Commit: 8098fab, github.com/apache/spark/pull/3603 + + HOTFIX: Disabling failing block manager test + Patrick Wendell + 2014-12-15 10:54:45 -0800 + Commit: 4c06738 + + fixed spelling errors in documentation + Peter Klipfel + 2014-12-14 00:01:16 -0800 + Commit: 2a2983f, github.com/apache/spark/pull/3691 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-12-11 23:38:40 -0800 + Commit: ef84dab, github.com/apache/spark/pull/3488 + + [SPARK-4829] [SQL] add rule to fold count(expr) if expr is not null + Daoyuan Wang + 2014-12-11 22:56:42 -0800 + Commit: 41a3f93, github.com/apache/spark/pull/3676 + + [SPARK-4742][SQL] The name of Parquet File generated by AppendingParquetOutputFormat should be zero padded + Sasaki Toru + 2014-12-11 22:54:21 -0800 + Commit: 8091dd6, github.com/apache/spark/pull/3602 + + [SPARK-4825] [SQL] CTAS fails to resolve when created using saveAsTable + Cheng Hao + 2014-12-11 22:51:49 -0800 + Commit: 0abbff2, github.com/apache/spark/pull/3673 + + [SQL] enable empty aggr test case + Daoyuan Wang + 2014-12-11 22:50:18 -0800 + Commit: cbb634a, github.com/apache/spark/pull/3445 + + [SPARK-4828] [SQL] sum and avg on empty table should always return null + Daoyuan Wang + 2014-12-11 22:49:27 -0800 + Commit: acb3be6, github.com/apache/spark/pull/3675 + + [SQL] Remove unnecessary case in HiveContext.toHiveString + scwf + 2014-12-11 22:48:03 -0800 + Commit: d8cf678, github.com/apache/spark/pull/3563 + + [SPARK-4293][SQL] Make Cast be able to handle complex types. + Takuya UESHIN + 2014-12-11 22:45:25 -0800 + Commit: 3344803, github.com/apache/spark/pull/3150 + + [SPARK-4639] [SQL] Pass maxIterations in as a parameter in Analyzer + Jacky Li + 2014-12-11 22:44:27 -0800 + Commit: c152dde, github.com/apache/spark/pull/3499 + + [SPARK-4662] [SQL] Whitelist more unittest + Cheng Hao + 2014-12-11 22:43:02 -0800 + Commit: a7f07f5, github.com/apache/spark/pull/3522 + + [SPARK-4713] [SQL] SchemaRDD.unpersist() should not raise exception if it is not persisted + Cheng Hao + 2014-12-11 22:41:36 -0800 + Commit: bf40cf8, github.com/apache/spark/pull/3572 + + [SPARK-4806] Streaming doc update for 1.2 + Tathagata Das , Josh Rosen , Josh Rosen + 2014-12-11 06:21:23 -0800 + Commit: b004150, github.com/apache/spark/pull/3653 + + [SPARK-4791] [sql] Infer schema from case class with multiple constructors + Joseph K. Bradley + 2014-12-10 23:41:15 -0800 + Commit: 2a5b5fd, github.com/apache/spark/pull/3646 + + [CORE]codeStyle: uniform ConcurrentHashMap define in StorageLevel.scala with other places + Zhang, Liye + 2014-12-10 20:44:59 -0800 + Commit: 57d37f9, github.com/apache/spark/pull/2793 + + SPARK-3526 Add section about data locality to the tuning guide + Andrew Ash + 2014-12-10 15:01:15 -0800 + Commit: 652b781, github.com/apache/spark/pull/2519 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-12-10 14:41:16 -0800 + Commit: 36bdb5b, github.com/apache/spark/pull/2883 + + [SPARK-4759] Fix driver hanging from coalescing partitions + Andrew Or + 2014-12-10 14:27:53 -0800 + Commit: 4f93d0c, github.com/apache/spark/pull/3633 + + [SPARK-4569] Rename 'externalSorting' in Aggregator + Ilya Ganelin + 2014-12-10 14:19:37 -0800 + Commit: 447ae2d, github.com/apache/spark/pull/3666 + + [SPARK-4793] [Deploy] ensure .jar at end of line + Daoyuan Wang + 2014-12-10 13:29:27 -0800 + Commit: e230da1, github.com/apache/spark/pull/3641 + + [SPARK-4215] Allow requesting / killing executors only in YARN mode + Andrew Or + 2014-12-10 12:48:24 -0800 + Commit: faa8fd8, github.com/apache/spark/pull/3615 + + [SPARK-4771][Docs] Document standalone cluster supervise mode + Andrew Or + 2014-12-10 12:41:36 -0800 + Commit: 5621283, github.com/apache/spark/pull/3627 + + [SPARK-4329][WebUI] HistoryPage pagenation + Kousuke Saruta + 2014-12-10 12:29:00 -0800 + Commit: 0fc637b, github.com/apache/spark/pull/3194 + + [SPARK-4161]Spark shell class path is not correctly set if "spark.driver.extraClassPath" is set in defaults.conf + GuoQiang Li + 2014-12-10 12:24:04 -0800 + Commit: 742e709, github.com/apache/spark/pull/3050 + + [SPARK-4772] Clear local copies of accumulators as soon as we're done with them + Nathan Kronenfeld + 2014-12-09 23:53:17 -0800 + Commit: 94b377f, github.com/apache/spark/pull/3570 + + [Minor] Use tag for help icon in web UI page header + Josh Rosen + 2014-12-09 23:47:05 -0800 + Commit: f79c1cf, github.com/apache/spark/pull/3659 + + Config updates for the new shuffle transport. + Reynold Xin + 2014-12-09 19:29:09 -0800 + Commit: 9bd9334, github.com/apache/spark/pull/3657 + + [SPARK-4740] Create multiple concurrent connections between two peer nodes in Netty. + Reynold Xin + 2014-12-09 17:49:59 -0800 + Commit: 2b9b726, github.com/apache/spark/pull/3625 + + SPARK-4805 [CORE] BlockTransferMessage.toByteArray() trips assertion + Sean Owen + 2014-12-09 16:38:27 -0800 + Commit: d8f84f2, github.com/apache/spark/pull/3650 + + SPARK-4567. Make SparkJobInfo and SparkStageInfo serializable + Sandy Ryza + 2014-12-09 16:26:07 -0800 + Commit: 5e4c06f, github.com/apache/spark/pull/3426 + + [SPARK-4714] BlockManager.dropFromMemory() should check whether block has been removed after synchronizing on BlockInfo instance. + hushan[胡珊] + 2014-12-09 15:11:20 -0800 + Commit: 30dca92, github.com/apache/spark/pull/3574 + + [SPARK-4765] Make GC time always shown in UI. + Kay Ousterhout + 2014-12-09 15:10:36 -0800 + Commit: 1f51106, github.com/apache/spark/pull/3622 + + [SPARK-4691][shuffle] Restructure a few lines in shuffle code + maji2014 + 2014-12-09 13:13:12 -0800 + Commit: b310744, github.com/apache/spark/pull/3553 + + [SPARK-874] adding a --wait flag + jbencook + 2014-12-09 12:16:19 -0800 + Commit: 61f1a70, github.com/apache/spark/pull/3567 + + SPARK-4338. [YARN] Ditch yarn-alpha. + Sandy Ryza + 2014-12-09 11:02:43 -0800 + Commit: 912563a, github.com/apache/spark/pull/3215 + + [SPARK-4785][SQL] Initilize Hive UDFs on the driver and serialize them with a wrapper + Cheng Hao , Cheng Lian + 2014-12-09 10:28:15 -0800 + Commit: 383c555, github.com/apache/spark/pull/3640 + + [SPARK-3154][STREAMING] Replace ConcurrentHashMap with mutable.HashMap and remove @volatile from 'stopped' + zsxwing + 2014-12-08 23:54:15 -0800 + Commit: bcb5cda, github.com/apache/spark/pull/3634 + + [SPARK-4769] [SQL] CTAS does not work when reading from temporary tables + Cheng Hao + 2014-12-08 17:39:12 -0800 + Commit: 51b1fe1, github.com/apache/spark/pull/3336 + + [SQL] remove unnecessary import in spark-sql + Jacky Li + 2014-12-08 17:27:46 -0800 + Commit: 9443843, github.com/apache/spark/pull/3630 + + SPARK-4770. [DOC] [YARN] spark.scheduler.minRegisteredResourcesRatio doc... + Sandy Ryza + 2014-12-08 16:28:36 -0800 + Commit: cda94d1, github.com/apache/spark/pull/3624 + + SPARK-3926 [CORE] Reopened: result of JavaRDD collectAsMap() is not serializable + Sean Owen + 2014-12-08 16:13:03 -0800 + Commit: e829bfa, github.com/apache/spark/pull/3587 + + [SPARK-4750] Dynamic allocation - synchronize kills + Andrew Or + 2014-12-08 16:02:33 -0800 + Commit: 65f929d, github.com/apache/spark/pull/3612 + + [SPARK-4774] [SQL] Makes HiveFromSpark more portable + Kostas Sakellis + 2014-12-08 15:44:18 -0800 + Commit: d6a972b, github.com/apache/spark/pull/3628 + + [SPARK-4764] Ensure that files are fetched atomically + Christophe Préaud + 2014-12-08 11:44:54 -0800 + Commit: ab2abcb, github.com/apache/spark/pull/2855 + + [SPARK-4620] Add unpersist in Graph and GraphImpl + Takeshi Yamamuro + 2014-12-07 19:42:02 -0800 + Commit: 8817fc7, github.com/apache/spark/pull/3476 + + [SPARK-4646] Replace Scala.util.Sorting.quickSort with Sorter(TimSort) in Spark + Takeshi Yamamuro + 2014-12-07 19:36:08 -0800 + Commit: 2e6b736, github.com/apache/spark/pull/3507 + + [SPARK-3623][GraphX] GraphX should support the checkpoint operation + GuoQiang Li + 2014-12-06 00:56:51 -0800 + Commit: e895e0c, github.com/apache/spark/pull/2631 + + Streaming doc : do you mean inadvertently? + CrazyJvm + 2014-12-05 13:42:13 -0800 + Commit: 6eb1b6f, github.com/apache/spark/pull/3620 + + [SPARK-4005][CORE] handle message replies in receive instead of in the individual private methods + Zhang, Liye + 2014-12-05 12:00:32 -0800 + Commit: 98a7d09, github.com/apache/spark/pull/2853 + + [SPARK-4761][SQL] Enables Kryo by default in Spark SQL Thrift server + Cheng Lian + 2014-12-05 10:27:40 -0800 + Commit: 6f61e1f, github.com/apache/spark/pull/3621 + + [SPARK-4753][SQL] Use catalyst for partition pruning in newParquet. + Michael Armbrust + 2014-12-04 22:25:21 -0800 + Commit: f5801e8, github.com/apache/spark/pull/3613 + + Revert "SPARK-2624 add datanucleus jars to the container in yarn-cluster" + Andrew Or + 2014-12-04 21:53:49 -0800 + Commit: fd85253 + + Revert "[HOT FIX] [YARN] Check whether `/lib` exists before listing its files" + Andrew Or + 2014-12-04 21:53:38 -0800 + Commit: 87437df + + [SPARK-4464] Description about configuration options need to be modified in docs. + Masayoshi TSUZUKI + 2014-12-04 19:33:02 -0800 + Commit: ca37903, github.com/apache/spark/pull/3329 + + Fix typo in Spark SQL docs. + Andy Konwinski + 2014-12-04 18:27:02 -0800 + Commit: 15cf3b0, github.com/apache/spark/pull/3611 + + [SPARK-4421] Wrong link in spark-standalone.html + Masayoshi TSUZUKI + 2014-12-04 18:14:36 -0800 + Commit: ddfc09c, github.com/apache/spark/pull/3279 + + [SPARK-4397] Move object RDD to the front of RDD.scala. + Reynold Xin + 2014-12-04 16:32:20 -0800 + Commit: ed92b47, github.com/apache/spark/pull/3580 + + [SPARK-4652][DOCS] Add docs about spark-git-repo option + lewuathe , Josh Rosen + 2014-12-04 15:14:36 -0800 + Commit: ab8177d, github.com/apache/spark/pull/3513 + + [SPARK-4459] Change groupBy type parameter from K to U + Saldanha + 2014-12-04 14:22:09 -0800 + Commit: 743a889, github.com/apache/spark/pull/3327 + + [SPARK-4745] Fix get_existing_cluster() function with multiple security groups + alexdebrie + 2014-12-04 14:13:59 -0800 + Commit: 794f3ae, github.com/apache/spark/pull/3596 + + [HOTFIX] Fixing two issues with the release script. + Patrick Wendell + 2014-12-04 12:11:41 -0800 + Commit: 8dae26f, github.com/apache/spark/pull/3608 + + [SPARK-4253] Ignore spark.driver.host in yarn-cluster and standalone-cluster modes + WangTaoTheTonic , WangTao + 2014-12-04 11:52:47 -0800 + Commit: 8106b1e, github.com/apache/spark/pull/3112 + + [SPARK-4683][SQL] Add a beeline.cmd to run on Windows + Cheng Lian + 2014-12-04 10:21:03 -0800 + Commit: 28c7aca, github.com/apache/spark/pull/3599 + + [FIX][DOC] Fix broken links in ml-guide.md + Xiangrui Meng + 2014-12-04 20:16:35 +0800 + Commit: 7e758d7, github.com/apache/spark/pull/3601 + + [SPARK-4575] [mllib] [docs] spark.ml pipelines doc + bug fixes + Joseph K. Bradley , jkbradley , Xiangrui Meng + 2014-12-04 17:00:06 +0800 + Commit: 469a6e5, github.com/apache/spark/pull/3588 + + [docs] Fix outdated comment in tuning guide + Joseph K. Bradley + 2014-12-04 00:59:32 -0800 + Commit: 529439b, github.com/apache/spark/pull/3592 + + [SQL] Minor: Avoid calling Seq#size in a loop + Aaron Davidson + 2014-12-04 00:58:42 -0800 + Commit: c6c7165, github.com/apache/spark/pull/3593 + + [SPARK-4685] Include all spark.ml and spark.mllib packages in JavaDoc's MLlib group + lewuathe , Xiangrui Meng + 2014-12-04 16:51:41 +0800 + Commit: 20bfea4, github.com/apache/spark/pull/3554 + + [SPARK-4719][API] Consolidate various narrow dep RDD classes with MapPartitionsRDD + Reynold Xin + 2014-12-04 00:45:57 -0800 + Commit: c3ad486, github.com/apache/spark/pull/3578 + + [SQL] remove unnecessary import + Jacky Li + 2014-12-04 00:43:55 -0800 + Commit: ed88db4, github.com/apache/spark/pull/3585 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-12-03 22:15:46 -0800 + Commit: 3cdae03, github.com/apache/spark/pull/1875 + + [Release] Correctly translate contributors name in release notes + Andrew Or + 2014-12-03 19:08:29 -0800 + Commit: a4dfb4e + + [SPARK-4580] [SPARK-4610] [mllib] [docs] Documentation for tree ensembles + DecisionTree API fix + Joseph K. Bradley , Joseph K. Bradley + 2014-12-04 09:57:50 +0800 + Commit: 657a888, github.com/apache/spark/pull/3461 + + [SPARK-4711] [mllib] [docs] Programming guide advice on choosing optimizer + Joseph K. Bradley + 2014-12-04 08:58:03 +0800 + Commit: 27ab0b8, github.com/apache/spark/pull/3569 + + [SPARK-4085] Propagate FetchFailedException when Spark fails to read local shuffle file. + Reynold Xin + 2014-12-03 16:28:24 -0800 + Commit: 1826372, github.com/apache/spark/pull/3579 + + [SPARK-4498][core] Don't transition ExecutorInfo to RUNNING until Driver adds Executor + Mark Hamstra + 2014-12-03 15:08:01 -0800 + Commit: 96b2785, github.com/apache/spark/pull/3550 + + [SPARK-4552][SQL] Avoid exception when reading empty parquet data through Hive + Michael Armbrust + 2014-12-03 14:13:35 -0800 + Commit: 513ef82, github.com/apache/spark/pull/3586 + + [HOT FIX] [YARN] Check whether `/lib` exists before listing its files + Andrew Or + 2014-12-03 13:56:23 -0800 + Commit: 90ec643, github.com/apache/spark/pull/3589 + + [SPARK-4642] Add description about spark.yarn.queue to running-on-YARN document. + Masayoshi TSUZUKI + 2014-12-03 13:16:24 -0800 + Commit: 692f493, github.com/apache/spark/pull/3500 + + [SPARK-4715][Core] Make sure tryToAcquire won't return a negative value + zsxwing + 2014-12-03 12:19:40 -0800 + Commit: edd3cd4, github.com/apache/spark/pull/3575 + + [SPARK-4701] Typo in sbt/sbt + Masayoshi TSUZUKI + 2014-12-03 12:08:00 -0800 + Commit: 96786e3, github.com/apache/spark/pull/3560 + + SPARK-2624 add datanucleus jars to the container in yarn-cluster + Jim Lim + 2014-12-03 11:16:02 -0800 + Commit: a975dc3, github.com/apache/spark/pull/3238 + + [SPARK-4717][MLlib] Optimize BLAS library to avoid de-reference multiple times in loop + DB Tsai + 2014-12-03 22:31:39 +0800 + Commit: d005429, github.com/apache/spark/pull/3577 + + [SPARK-4708][MLLib] Make k-mean runs two/three times faster with dense/sparse sample + DB Tsai + 2014-12-03 19:01:56 +0800 + Commit: 7fc49ed, github.com/apache/spark/pull/3565 + + [SPARK-4710] [mllib] Eliminate MLlib compilation warnings + Joseph K. Bradley + 2014-12-03 18:50:03 +0800 + Commit: 4ac2151, github.com/apache/spark/pull/3568 + + [SPARK-4397][Core] Change the 'since' value of '@deprecated' to '1.3.0' + zsxwing + 2014-12-03 02:05:17 -0800 + Commit: 8af551f, github.com/apache/spark/pull/3573 + + [SPARK-4672][Core]Checkpoint() should clear f to shorten the serialization chain + JerryLead , Lijie Xu + 2014-12-02 23:53:29 -0800 + Commit: 77be8b9, github.com/apache/spark/pull/3545 + + [SPARK-4672][GraphX]Non-transient PartitionsRDDs will lead to StackOverflow error + JerryLead , Lijie Xu + 2014-12-02 17:14:11 -0800 + Commit: 17c162f, github.com/apache/spark/pull/3544 + + [SPARK-4672][GraphX]Perform checkpoint() on PartitionsRDD to shorten the lineage + JerryLead , Lijie Xu + 2014-12-02 17:08:02 -0800 + Commit: fc0a147, github.com/apache/spark/pull/3549 + + [Release] Translate unknown author names automatically + Andrew Or + 2014-12-02 16:36:12 -0800 + Commit: 5da21f0 + + Minor nit style cleanup in GraphX. + Reynold Xin + 2014-12-02 14:40:26 -0800 + Commit: 2d4f6e7 + + [SPARK-4695][SQL] Get result using executeCollect + wangfei + 2014-12-02 14:30:44 -0800 + Commit: 3ae0cda, github.com/apache/spark/pull/3547 + + [SPARK-4670] [SQL] wrong symbol for bitwise not + Daoyuan Wang + 2014-12-02 14:25:12 -0800 + Commit: 1f5ddf1, github.com/apache/spark/pull/3528 + + [SPARK-4593][SQL] Return null when denominator is 0 + Daoyuan Wang + 2014-12-02 14:21:12 -0800 + Commit: f6df609, github.com/apache/spark/pull/3443 + + [SPARK-4676][SQL] JavaSchemaRDD.schema may throw NullType MatchError if sql has null + YanTangZhai , yantangzhai , Michael Armbrust + 2014-12-02 14:12:48 -0800 + Commit: 1066427, github.com/apache/spark/pull/3538 + + [SPARK-4663][sql]add finally to avoid resource leak + baishuo + 2014-12-02 12:12:03 -0800 + Commit: 69b6fed, github.com/apache/spark/pull/3526 + + [SPARK-4536][SQL] Add sqrt and abs to Spark SQL DSL + Kousuke Saruta + 2014-12-02 12:07:52 -0800 + Commit: e75e04f, github.com/apache/spark/pull/3401 + + Indent license header properly for interfaces.scala. + Reynold Xin + 2014-12-02 11:59:15 -0800 + Commit: b1f8fe3, github.com/apache/spark/pull/3552 + + [SPARK-4686] Link to allowed master URLs is broken + Kay Ousterhout + 2014-12-02 09:06:02 -0800 + Commit: d9a148b, github.com/apache/spark/pull/3542 + + [SPARK-4397][Core] Cleanup 'import SparkContext._' in core + zsxwing + 2014-12-02 00:18:41 -0800 + Commit: 6dfe38a, github.com/apache/spark/pull/3530 + + [SPARK-4611][MLlib] Implement the efficient vector norm + DB Tsai + 2014-12-02 11:40:43 +0800 + Commit: 64f3175, github.com/apache/spark/pull/3462 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-12-01 17:27:14 -0800 + Commit: b0a46d8, github.com/apache/spark/pull/1612 + + [SPARK-4268][SQL] Use #::: to get benefit from Stream in SqlLexical.allCaseVersions + zsxwing + 2014-12-01 16:39:54 -0800 + Commit: d3e02dd, github.com/apache/spark/pull/3132 + + [SPARK-4529] [SQL] support view with column alias + Daoyuan Wang + 2014-12-01 16:08:51 -0800 + Commit: 4df60a8, github.com/apache/spark/pull/3396 + + [SQL][DOC] Date type in SQL programming guide + Daoyuan Wang + 2014-12-01 14:03:57 -0800 + Commit: 5edbcbf, github.com/apache/spark/pull/3535 + + [SQL] Minor fix for doc and comment + wangfei + 2014-12-01 14:02:02 -0800 + Commit: 7b79957, github.com/apache/spark/pull/3533 + + [SPARK-4658][SQL] Code documentation issue in DDL of datasource API + ravipesala + 2014-12-01 13:31:27 -0800 + Commit: bc35381, github.com/apache/spark/pull/3516 + + [SPARK-4650][SQL] Supporting multi column support in countDistinct function like count(distinct c1,c2..) in Spark SQL + ravipesala , Michael Armbrust + 2014-12-01 13:26:44 -0800 + Commit: 6a9ff19, github.com/apache/spark/pull/3511 + + [SPARK-4358][SQL] Let BigDecimal do checking type compatibility + Liang-Chi Hsieh + 2014-12-01 13:17:56 -0800 + Commit: b57365a, github.com/apache/spark/pull/3208 + + [SQL] add @group tab in limit() and count() + Jacky Li + 2014-12-01 13:12:30 -0800 + Commit: bafee67, github.com/apache/spark/pull/3458 + + [SPARK-4258][SQL][DOC] Documents spark.sql.parquet.filterPushdown + Cheng Lian + 2014-12-01 13:09:51 -0800 + Commit: 5db8dca, github.com/apache/spark/pull/3440 + + Documentation: add description for repartitionAndSortWithinPartitions + Madhu Siddalingaiah + 2014-12-01 08:45:34 -0800 + Commit: 2b233f5, github.com/apache/spark/pull/3390 + + [SPARK-4661][Core] Minor code and docs cleanup + zsxwing + 2014-12-01 00:35:01 -0800 + Commit: 30a86ac, github.com/apache/spark/pull/3521 + + [SPARK-4664][Core] Throw an exception when spark.akka.frameSize > 2047 + zsxwing + 2014-12-01 00:32:54 -0800 + Commit: 1d238f2, github.com/apache/spark/pull/3527 + + SPARK-2192 [BUILD] Examples Data Not in Binary Distribution + Sean Owen + 2014-12-01 16:31:04 +0800 + Commit: 6384f42, github.com/apache/spark/pull/3480 + + Fix wrong file name pattern in .gitignore + Kousuke Saruta + 2014-12-01 00:29:28 -0800 + Commit: 97eb6d7, github.com/apache/spark/pull/3529 + + [SPARK-4632] version update + Prabeesh K + 2014-11-30 20:51:53 -0800 + Commit: 5e7a6dc, github.com/apache/spark/pull/3495 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-11-30 20:51:13 -0800 + Commit: 06dc1b1, github.com/apache/spark/pull/2915 + + [DOC] Fixes formatting typo in SQL programming guide + Cheng Lian + 2014-11-30 19:04:07 -0800 + Commit: 2a4d389, github.com/apache/spark/pull/3498 + + [SPARK-4656][Doc] Typo in Programming Guide markdown + lewuathe + 2014-11-30 17:18:50 -0800 + Commit: a217ec5, github.com/apache/spark/pull/3412 + + [SPARK-4623]Add the some error infomation if using spark-sql in yarn-cluster mode + carlmartin , huangzhaowei + 2014-11-30 16:19:41 -0800 + Commit: aea7a99, github.com/apache/spark/pull/3479 + + SPARK-2143 [WEB UI] Add Spark version to UI footer + Sean Owen + 2014-11-30 11:40:08 -0800 + Commit: 048ecca, github.com/apache/spark/pull/3410 + + [DOCS][BUILD] Add instruction to use change-version-to-2.11.sh in 'Building for Scala 2.11'. + Takuya UESHIN + 2014-11-30 00:10:31 -0500 + Commit: 0fcd24c, github.com/apache/spark/pull/3361 + + SPARK-4507: PR merge script should support closing multiple JIRA tickets + Takayuki Hasegawa + 2014-11-29 23:12:10 -0500 + Commit: 4316a7b, github.com/apache/spark/pull/3428 + + [SPARK-4505][Core] Add a ClassTag parameter to CompactBuffer[T] + zsxwing + 2014-11-29 20:23:08 -0500 + Commit: c062224, github.com/apache/spark/pull/3378 + + [SPARK-4057] Use -agentlib instead of -Xdebug in sbt-launch-lib.bash for debugging + Kousuke Saruta + 2014-11-29 20:14:14 -0500 + Commit: 938dc14, github.com/apache/spark/pull/2904 + + Include the key name when failing on an invalid value. + Stephen Haberman + 2014-11-29 20:12:05 -0500 + Commit: 95290bf, github.com/apache/spark/pull/3514 + + [SPARK-3398] [SPARK-4325] [EC2] Use EC2 status checks. + Nicholas Chammas + 2014-11-29 00:31:06 -0800 + Commit: 317e114, github.com/apache/spark/pull/3195 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-11-29 00:24:35 -0500 + Commit: 047ff57, github.com/apache/spark/pull/3451 + + [SPARK-4597] Use proper exception and reset variable in Utils.createTempDir() + Liang-Chi Hsieh + 2014-11-28 18:04:05 -0800 + Commit: 49fe879, github.com/apache/spark/pull/3449 + + SPARK-1450 [EC2] Specify the default zone in the EC2 script help + Sean Owen + 2014-11-28 17:43:38 -0500 + Commit: 48223d8, github.com/apache/spark/pull/3454 + + [SPARK-4584] [yarn] Remove security manager from Yarn AM. + Marcelo Vanzin + 2014-11-28 15:15:30 -0500 + Commit: 915f8ee, github.com/apache/spark/pull/3484 + + [SPARK-4193][BUILD] Disable doclint in Java 8 to prevent from build error. + Takuya UESHIN + 2014-11-28 13:00:15 -0500 + Commit: e464f0a, github.com/apache/spark/pull/3058 + + [SPARK-4643] [Build] Remove unneeded staging repositories from build + Daoyuan Wang + 2014-11-28 12:41:38 -0500 + Commit: 53ed7f1, github.com/apache/spark/pull/3504 + + Delete unnecessary function + KaiXinXiaoLei + 2014-11-28 12:34:07 -0500 + Commit: 052e658, github.com/apache/spark/pull/3224 + + [SPARK-4645][SQL] Disables asynchronous execution in Hive 0.13.1 HiveThriftServer2 + Cheng Lian + 2014-11-28 11:42:40 -0500 + Commit: 5b99bf2, github.com/apache/spark/pull/3506 + + [SPARK-4619][Storage]delete redundant time suffix + maji2014 + 2014-11-28 00:36:22 -0800 + Commit: ceb6281, github.com/apache/spark/pull/3475 + + [SPARK-4613][Core] Java API for JdbcRDD + Cheng Lian + 2014-11-27 18:01:14 -0800 + Commit: 120a350, github.com/apache/spark/pull/3478 + + [SPARK-4626] Kill a task only if the executorId is (still) registered with the scheduler + roxchkplusony + 2014-11-27 15:54:40 -0800 + Commit: 84376d3, github.com/apache/spark/pull/3483 + + SPARK-4170 [CORE] Closure problems when running Scala app that "extends App" + Sean Owen + 2014-11-27 09:03:17 -0800 + Commit: 5d7fe17, github.com/apache/spark/pull/3497 + + [Release] Automate generation of contributors list + Andrew Or + 2014-11-26 23:16:23 -0800 + Commit: c86e9bc + + [SPARK-732][SPARK-3628][CORE][RESUBMIT] eliminate duplicate update on accmulator + CodingCat + 2014-11-26 16:52:04 -0800 + Commit: 5af53ad, github.com/apache/spark/pull/2524 + + [SPARK-4614][MLLIB] Slight API changes in Matrix and Matrices + Xiangrui Meng + 2014-11-26 08:22:50 -0800 + Commit: 561d31d, github.com/apache/spark/pull/3468 + + Removing confusing TripletFields + Joseph E. Gonzalez + 2014-11-26 00:55:28 -0800 + Commit: 288ce58, github.com/apache/spark/pull/3472 + + [SPARK-4612] Reduce task latency and increase scheduling throughput by making configuration initialization lazy + Tathagata Das + 2014-11-25 23:15:58 -0800 + Commit: e7f4d25, github.com/apache/spark/pull/3463 + + [SPARK-4516] Avoid allocating Netty PooledByteBufAllocators unnecessarily + Aaron Davidson + 2014-11-26 00:32:45 -0500 + Commit: 346bc17, github.com/apache/spark/pull/3465 + + [SPARK-4516] Cap default number of Netty threads at 8 + Aaron Davidson + 2014-11-25 23:57:04 -0500 + Commit: f5f2d27, github.com/apache/spark/pull/3469 + + [SPARK-4604][MLLIB] make MatrixFactorizationModel public + Xiangrui Meng + 2014-11-25 20:11:40 -0800 + Commit: b5fb141, github.com/apache/spark/pull/3459 + + [HOTFIX]: Adding back without-hive dist + Patrick Wendell + 2014-11-25 23:10:19 -0500 + Commit: 4d95526 + + [SPARK-4583] [mllib] LogLoss for GradientBoostedTrees fix + doc updates + Joseph K. Bradley + 2014-11-25 20:10:15 -0800 + Commit: c251fd7, github.com/apache/spark/pull/3439 + + [Spark-4509] Revert EC2 tag-based cluster membership patch + Xiangrui Meng + 2014-11-25 16:07:09 -0800 + Commit: 7eba0fb, github.com/apache/spark/pull/3453 + + Fix SPARK-4471: blockManagerIdFromJson function throws exception while B... + hushan[胡珊] + 2014-11-25 15:51:08 -0800 + Commit: 9bdf5da, github.com/apache/spark/pull/3340 + + [SPARK-4546] Improve HistoryServer first time user experience + Andrew Or + 2014-11-25 15:48:02 -0800 + Commit: 9afcbe4, github.com/apache/spark/pull/3411 + + [SPARK-4592] Avoid duplicate worker registrations in standalone mode + Andrew Or + 2014-11-25 15:46:26 -0800 + Commit: 1b2ab1c, github.com/apache/spark/pull/3447 + + [SPARK-4196][SPARK-4602][Streaming] Fix serialization issue in PairDStreamFunctions.saveAsNewAPIHadoopFiles + Tathagata Das + 2014-11-25 14:16:27 -0800 + Commit: 8838ad7, github.com/apache/spark/pull/3457 + + [SPARK-4581][MLlib] Refactorize StandardScaler to improve the transformation performance + DB Tsai + 2014-11-25 11:07:11 -0800 + Commit: bf1a6aa, github.com/apache/spark/pull/3435 + + [SPARK-4601][Streaming] Set correct call site for streaming jobs so that it is displayed correctly on the Spark UI + Tathagata Das + 2014-11-25 06:50:36 -0800 + Commit: 69cd53e, github.com/apache/spark/pull/3455 + + [SPARK-4344][DOCS] adding documentation on spark.yarn.user.classpath.first + arahuja + 2014-11-25 08:23:41 -0600 + Commit: d240760, github.com/apache/spark/pull/3209 + + [SPARK-4381][Streaming]Add warning log when user set spark.master to local in Spark Streaming and there's no job executed + jerryshao + 2014-11-25 05:36:29 -0800 + Commit: fef27b2, github.com/apache/spark/pull/3244 + + [SPARK-4535][Streaming] Fix the error in comments + q00251598 + 2014-11-25 04:01:56 -0800 + Commit: a51118a, github.com/apache/spark/pull/3400 + + [SPARK-4526][MLLIB]GradientDescent get a wrong gradient value according to the gradient formula. + GuoQiang Li + 2014-11-25 02:01:19 -0800 + Commit: f515f94, github.com/apache/spark/pull/3399 + + [SPARK-4596][MLLib] Refactorize Normalizer to make code cleaner + DB Tsai + 2014-11-25 01:57:34 -0800 + Commit: 89f9122, github.com/apache/spark/pull/3446 + + [DOC][Build] Wrong cmd for build spark with apache hadoop 2.4.X and hive 12 + wangfei + 2014-11-24 22:32:39 -0800 + Commit: 0fe54cf, github.com/apache/spark/pull/3335 + + [SQL] Compute timeTaken correctly + w00228970 + 2014-11-24 21:17:24 -0800 + Commit: 723be60, github.com/apache/spark/pull/3423 + + [SPARK-4582][MLLIB] get raw vectors for further processing in Word2Vec + tkaessmann , tkaessmann + 2014-11-24 19:58:01 -0800 + Commit: 9ce2bf3, github.com/apache/spark/pull/3309 + + [SPARK-4525] Mesos should decline unused offers + Patrick Wendell , Jongyoul Lee + 2014-11-24 19:14:14 -0800 + Commit: f0afb62, github.com/apache/spark/pull/3436 + + Revert "[SPARK-4525] Mesos should decline unused offers" + Patrick Wendell + 2014-11-24 19:16:53 -0800 + Commit: a68d442 + + [SPARK-4525] Mesos should decline unused offers + Patrick Wendell , Jongyoul Lee + 2014-11-24 19:14:14 -0800 + Commit: b043c27, github.com/apache/spark/pull/3436 + + [SPARK-4266] [Web-UI] Reduce stage page load time. + Kay Ousterhout + 2014-11-24 18:03:10 -0800 + Commit: d24d5bf, github.com/apache/spark/pull/3328 + + [SPARK-4548] []SPARK-4517] improve performance of python broadcast + Davies Liu + 2014-11-24 17:17:03 -0800 + Commit: 6cf5076, github.com/apache/spark/pull/3417 + + [SPARK-4578] fix asDict() with nested Row() + Davies Liu + 2014-11-24 16:41:23 -0800 + Commit: 050616b, github.com/apache/spark/pull/3434 + + [SPARK-4562] [MLlib] speedup vector + Davies Liu + 2014-11-24 16:37:14 -0800 + Commit: b660de7, github.com/apache/spark/pull/3420 + + [SPARK-4518][SPARK-4519][Streaming] Refactored file stream to prevent files from being processed multiple times + Tathagata Das + 2014-11-24 13:50:20 -0800 + Commit: cb0e9b0, github.com/apache/spark/pull/3419 + + [SPARK-4145] Web UI job pages + Josh Rosen + 2014-11-24 13:18:14 -0800 + Commit: 4a90276, github.com/apache/spark/pull/3009 + + [SPARK-4487][SQL] Fix attribute reference resolution error when using ORDER BY. + Kousuke Saruta + 2014-11-24 12:54:37 -0800 + Commit: dd1c9cb, github.com/apache/spark/pull/3363 + + [SQL] Fix path in HiveFromSpark + scwf + 2014-11-24 12:49:08 -0800 + Commit: b384119, github.com/apache/spark/pull/3415 + + [SQL] Fix comment in HiveShim + Daniel Darabos + 2014-11-24 12:45:07 -0800 + Commit: d5834f0, github.com/apache/spark/pull/3432 + + [SPARK-4479][SQL] Avoids unnecessary defensive copies when sort based shuffle is on + Cheng Lian + 2014-11-24 12:43:45 -0800 + Commit: a6d7b61, github.com/apache/spark/pull/3422 + + SPARK-4457. Document how to build for Hadoop versions greater than 2.4 + Sandy Ryza + 2014-11-24 13:28:48 -0600 + Commit: 29372b6, github.com/apache/spark/pull/3322 + + [SPARK-4377] Fixed serialization issue by switching to akka provided serializer. + Prashant Sharma + 2014-11-22 14:05:38 -0800 + Commit: 9b2a3c6, github.com/apache/spark/pull/3402 + + [SPARK-4431][MLlib] Implement efficient foreachActive for dense and sparse vector + DB Tsai + 2014-11-21 18:15:07 -0800 + Commit: b5d17ef, github.com/apache/spark/pull/3288 + + [SPARK-4531] [MLlib] cache serialized java object + Davies Liu + 2014-11-21 15:02:31 -0800 + Commit: ce95bd8, github.com/apache/spark/pull/3397 + + SPARK-4532: Fix bug in detection of Hive in Spark 1.2 + Patrick Wendell + 2014-11-21 12:10:04 -0800 + Commit: a81918c, github.com/apache/spark/pull/3398 + + [SPARK-4397][Core] Reorganize 'implicit's to improve the API convenience + zsxwing + 2014-11-21 10:06:30 -0800 + Commit: 65b987c, github.com/apache/spark/pull/3262 + + [SPARK-4472][Shell] Print "Spark context available as sc." only when SparkContext is created... + zsxwing + 2014-11-21 00:42:43 -0800 + Commit: f1069b8, github.com/apache/spark/pull/3341 + + [Doc][GraphX] Remove unused png files. + Reynold Xin + 2014-11-21 00:30:58 -0800 + Commit: 28fdc6f + + [Doc][GraphX] Remove Motivation section and did some minor update. + Reynold Xin + 2014-11-21 00:29:02 -0800 + Commit: b97070e + + [SPARK-4522][SQL] Parse schema with missing metadata. + Michael Armbrust + 2014-11-20 20:34:43 -0800 + Commit: 90a6a46, github.com/apache/spark/pull/3392 + + add Sphinx as a dependency of building docs + Davies Liu + 2014-11-20 19:12:45 -0800 + Commit: 8cd6eea, github.com/apache/spark/pull/3388 + + [SPARK-4413][SQL] Parquet support through datasource API + Michael Armbrust + 2014-11-20 18:31:02 -0800 + Commit: 02ec058, github.com/apache/spark/pull/3269 + + [SPARK-4244] [SQL] Support Hive Generic UDFs with constant object inspector parameters + Cheng Hao + 2014-11-20 16:50:59 -0800 + Commit: 84d79ee, github.com/apache/spark/pull/3109 + + [SPARK-4477] [PySpark] remove numpy from RDDSampler + Davies Liu , Xiangrui Meng + 2014-11-20 16:40:25 -0800 + Commit: d39f2e9, github.com/apache/spark/pull/3351 + + [SQL] fix function description mistake + Jacky Li + 2014-11-20 15:48:36 -0800 + Commit: ad5f1f3, github.com/apache/spark/pull/3344 + + [SPARK-2918] [SQL] Support the CTAS in EXPLAIN command + Cheng Hao + 2014-11-20 15:46:00 -0800 + Commit: 6aa0fc9, github.com/apache/spark/pull/3357 + + [SPARK-4318][SQL] Fix empty sum distinct. + Takuya UESHIN + 2014-11-20 15:41:24 -0800 + Commit: 2c2e7a4, github.com/apache/spark/pull/3184 + + [SPARK-4513][SQL] Support relational operator '<=>' in Spark SQL + ravipesala + 2014-11-20 15:34:03 -0800 + Commit: 98e9419, github.com/apache/spark/pull/3387 + + [SPARK-4439] [MLlib] add python api for random forest + Davies Liu + 2014-11-20 15:31:28 -0800 + Commit: 1c53a5d, github.com/apache/spark/pull/3320 + + [SPARK-4228][SQL] SchemaRDD to JSON + Dan McClary + 2014-11-20 13:36:50 -0800 + Commit: b8e6886, github.com/apache/spark/pull/3213 + + [SPARK-3938][SQL] Names in-memory columnar RDD with corresponding table name + Cheng Lian + 2014-11-20 13:12:24 -0800 + Commit: abf2918, github.com/apache/spark/pull/3383 + + [SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc + Xiangrui Meng + 2014-11-20 00:48:59 -0800 + Commit: 15cacc8, github.com/apache/spark/pull/3374 + + [SPARK-4446] [SPARK CORE] + Leolh + 2014-11-19 18:18:55 -0800 + Commit: e216ffa, github.com/apache/spark/pull/3306 + + [SPARK-4480] Avoid many small spills in external data structures + Andrew Or + 2014-11-19 18:07:27 -0800 + Commit: 0eb4a7f, github.com/apache/spark/pull/3353 + + [Spark-4484] Treat maxResultSize as unlimited when set to 0; improve error message + Nishkam Ravi , nravi , nishkamravi2 + 2014-11-19 17:23:42 -0800 + Commit: 73fedf5, github.com/apache/spark/pull/3360 + + [SPARK-4478] Keep totalRegisteredExecutors up-to-date + Akshat Aranya + 2014-11-19 17:20:20 -0800 + Commit: 9ccc53c, github.com/apache/spark/pull/3373 + + Updating GraphX programming guide and documentation + Joseph E. Gonzalez + 2014-11-19 16:53:33 -0800 + Commit: 377b068, github.com/apache/spark/pull/3359 + + [SPARK-4495] Fix memory leak in JobProgressListener + Josh Rosen + 2014-11-19 16:50:21 -0800 + Commit: 04d462f, github.com/apache/spark/pull/3372 + + [SPARK-4294][Streaming] UnionDStream stream should express the requirements in the same way as TransformedDStream + Yadong Qi + 2014-11-19 15:53:06 -0800 + Commit: c3002c4, github.com/apache/spark/pull/3152 + + [SPARK-4384] [PySpark] improve sort spilling + Davies Liu + 2014-11-19 15:45:37 -0800 + Commit: 73c8ea8, github.com/apache/spark/pull/3252 + + [SPARK-4429][BUILD] Build for Scala 2.11 using sbt fails. + Takuya UESHIN + 2014-11-19 14:40:21 -0800 + Commit: f9adda9, github.com/apache/spark/pull/3342 + + [DOC][PySpark][Streaming] Fix docstring for sphinx + Ken Takagiwa + 2014-11-19 14:23:18 -0800 + Commit: 9b7bbce, github.com/apache/spark/pull/3311 + + SPARK-3962 Marked scope as provided for external projects. + Prashant Sharma , Prashant Sharma + 2014-11-19 14:18:10 -0800 + Commit: 1c93841, github.com/apache/spark/pull/2959 + + [HOT FIX] MiMa tests are broken + Andrew Or + 2014-11-19 14:03:44 -0800 + Commit: 0df02ca, github.com/apache/spark/pull/3371 + + [SPARK-4481][Streaming][Doc] Fix the wrong description of updateFunc + zsxwing + 2014-11-19 13:17:15 -0800 + Commit: 3bf7cee, github.com/apache/spark/pull/3356 + + [SPARK-4482][Streaming] Disable ReceivedBlockTracker's write ahead log by default + Tathagata Das + 2014-11-19 13:06:48 -0800 + Commit: 22fc4e7, github.com/apache/spark/pull/3358 + + [SPARK-4470] Validate number of threads in local mode + Kenichi Maehashi + 2014-11-19 12:11:09 -0800 + Commit: eacc788, github.com/apache/spark/pull/3337 + + [SPARK-4467] fix elements read count for ExtrenalSorter + Tianshuo Deng + 2014-11-19 10:01:09 -0800 + Commit: d75579d, github.com/apache/spark/pull/3302 + + SPARK-4455 Exclude dependency on hbase-annotations module + tedyu + 2014-11-19 00:55:39 -0800 + Commit: 5f5ac2d, github.com/apache/spark/pull/3286 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-11-19 00:27:31 -0800 + Commit: 8327df6, github.com/apache/spark/pull/2777 + + [Spark-4432]close InStream after the block is accessed + Mingfei + 2014-11-18 22:17:06 -0800 + Commit: 165cec9, github.com/apache/spark/pull/3290 + + [SPARK-4441] Close Tachyon client when TachyonBlockManager is shutdown + Mingfei + 2014-11-18 22:16:36 -0800 + Commit: 67e9876, github.com/apache/spark/pull/3299 + + Bumping version to 1.3.0-SNAPSHOT. + Marcelo Vanzin + 2014-11-18 21:24:18 -0800 + Commit: 397d3aa, github.com/apache/spark/pull/3277 + + [SPARK-4468][SQL] Fixes Parquet filter creation for inequality predicates with literals on the left hand side + Cheng Lian + 2014-11-18 17:41:54 -0800 + Commit: 423baea, github.com/apache/spark/pull/3334 + + [SPARK-4327] [PySpark] Python API for RDD.randomSplit() + Davies Liu + 2014-11-18 16:37:35 -0800 + Commit: 7f22fa8, github.com/apache/spark/pull/3193 + + [SPARK-4433] fix a racing condition in zipWithIndex + Xiangrui Meng + 2014-11-18 16:25:44 -0800 + Commit: bb46046, github.com/apache/spark/pull/3291 + + [SPARK-3721] [PySpark] broadcast objects larger than 2G + Davies Liu , Davies Liu + 2014-11-18 16:17:51 -0800 + Commit: 4a377af, github.com/apache/spark/pull/2659 + + [SPARK-4306] [MLlib] Python API for LogisticRegressionWithLBFGS + Davies Liu + 2014-11-18 15:57:33 -0800 + Commit: d2e2951, github.com/apache/spark/pull/3307 + + [SPARK-4463] Add (de)select all button for add'l metrics. + Kay Ousterhout + 2014-11-18 15:01:06 -0800 + Commit: 010bc86, github.com/apache/spark/pull/3331 + + [SPARK-4017] show progress bar in console + Davies Liu + 2014-11-18 13:37:21 -0800 + Commit: e34f38f, github.com/apache/spark/pull/3029 + + [SPARK-4404] remove sys.exit() in shutdown hook + Davies Liu + 2014-11-18 13:11:38 -0800 + Commit: 80f3177, github.com/apache/spark/pull/3289 + + [SPARK-4075][SPARK-4434] Fix the URI validation logic for Application Jar name. + Kousuke Saruta + 2014-11-18 12:17:33 -0800 + Commit: bfebfd8, github.com/apache/spark/pull/3326 + + [SQL] Support partitioned parquet tables that have the key in both the directory and the file + Michael Armbrust + 2014-11-18 12:13:23 -0800 + Commit: 90d72ec, github.com/apache/spark/pull/3272 + + [SPARK-4396] allow lookup by index in Python's Rating + Xiangrui Meng + 2014-11-18 10:35:29 -0800 + Commit: b54c6ab, github.com/apache/spark/pull/3261 + + [SPARK-4435] [MLlib] [PySpark] improve classification + Davies Liu + 2014-11-18 10:11:13 -0800 + Commit: 8fbf72b, github.com/apache/spark/pull/3305 + + ALS implicit: added missing parameter alpha in doc string + Felix Maximilian Möller + 2014-11-18 10:08:24 -0800 + Commit: cedc3b5, github.com/apache/spark/pull/3343 + + SPARK-4466: Provide support for publishing Scala 2.11 artifacts to Maven + Patrick Wendell + 2014-11-17 21:07:50 -0800 + Commit: c6e0c2a, github.com/apache/spark/pull/3332 + + [SPARK-4453][SPARK-4213][SQL] Simplifies Parquet filter generation code + Cheng Lian + 2014-11-17 16:55:12 -0800 + Commit: 36b0956, github.com/apache/spark/pull/3317 + + [SPARK-4448] [SQL] unwrap for the ConstantObjectInspector + Cheng Hao + 2014-11-17 16:35:49 -0800 + Commit: ef7c464, github.com/apache/spark/pull/3308 + + [SPARK-4443][SQL] Fix statistics for external table in spark sql hive + w00228970 + 2014-11-17 16:33:50 -0800 + Commit: 42389b1, github.com/apache/spark/pull/3304 + + [SPARK-4309][SPARK-4407][SQL] Date type support for Thrift server, and fixes for complex types + Cheng Lian + 2014-11-17 16:31:05 -0800 + Commit: 6b7f2f7, github.com/apache/spark/pull/3298 + + [SQL] Construct the MutableRow from an Array + Cheng Hao + 2014-11-17 16:29:52 -0800 + Commit: 69e858c, github.com/apache/spark/pull/3217 + + [SPARK-4425][SQL] Handle NaN or Infinity cast to Timestamp correctly. + Takuya UESHIN + 2014-11-17 16:28:07 -0800 + Commit: 566c791, github.com/apache/spark/pull/3283 + + [SPARK-4420][SQL] Change nullability of Cast from DoubleType/FloatType to DecimalType. + Takuya UESHIN + 2014-11-17 16:26:48 -0800 + Commit: 3a81a1c, github.com/apache/spark/pull/3278 + + [SQL] Makes conjunction pushdown more aggressive for in-memory table + Cheng Lian + 2014-11-17 15:33:13 -0800 + Commit: 5ce7dae, github.com/apache/spark/pull/3318 + + [SPARK-4180] [Core] Prevent creation of multiple active SparkContexts + Josh Rosen + 2014-11-17 12:48:18 -0800 + Commit: 0f3ceb5, github.com/apache/spark/pull/3121 + + [DOCS][SQL] Fix broken link to Row class scaladoc + Andy Konwinski + 2014-11-17 11:52:23 -0800 + Commit: cec1116, github.com/apache/spark/pull/3323 + + Revert "[SPARK-4075] [Deploy] Jar url validation is not enough for Jar file" + Andrew Or + 2014-11-17 11:24:28 -0800 + Commit: dbb9da5 + + [SPARK-4444] Drop VD type parameter from EdgeRDD + Ankur Dave + 2014-11-17 11:06:31 -0800 + Commit: 9ac2bb1, github.com/apache/spark/pull/3303 + + SPARK-2811 upgrade algebird to 0.8.1 + Adam Pingel + 2014-11-17 10:47:29 -0800 + Commit: e7690ed, github.com/apache/spark/pull/3282 + + SPARK-4445, Don't display storage level in toDebugString unless RDD is persisted. + Prashant Sharma + 2014-11-17 10:40:33 -0800 + Commit: 5c92d47, github.com/apache/spark/pull/3310 + + [SPARK-4410][SQL] Add support for external sort + Michael Armbrust + 2014-11-16 21:55:57 -0800 + Commit: 64c6b9b, github.com/apache/spark/pull/3268 + + [SPARK-4422][MLLIB]In some cases, Vectors.fromBreeze get wrong results. + GuoQiang Li + 2014-11-16 21:31:51 -0800 + Commit: 5168c6c, github.com/apache/spark/pull/3281 + + Revert "[SPARK-4309][SPARK-4407][SQL] Date type support for Thrift server, and fixes for complex types" + Michael Armbrust + 2014-11-16 15:05:04 -0800 + Commit: 45ce327, github.com/apache/spark/pull/3292 + + [SPARK-4309][SPARK-4407][SQL] Date type support for Thrift server, and fixes for complex types + Cheng Lian + 2014-11-16 14:26:41 -0800 + Commit: cb6bd83, github.com/apache/spark/pull/3178 + + [SPARK-4393] Fix memory leak in ConnectionManager ACK timeout TimerTasks; use HashedWheelTimer + Josh Rosen + 2014-11-16 00:44:15 -0800 + Commit: 7850e0c, github.com/apache/spark/pull/3259 + + [SPARK-4426][SQL][Minor] The symbol of BitwiseOr is wrong, should not be '&' + Kousuke Saruta + 2014-11-15 22:23:47 -0800 + Commit: 84468b2, github.com/apache/spark/pull/3284 + + [SPARK-4419] Upgrade snappy-java to 1.1.1.6 + Josh Rosen + 2014-11-15 22:22:34 -0800 + Commit: 7d8e152, github.com/apache/spark/pull/3287 + + [SPARK-2321] Several progress API improvements / refactorings + Josh Rosen + 2014-11-14 23:46:25 -0800 + Commit: 40eb8b6, github.com/apache/spark/pull/3197 + + Added contains(key) to Metadata + kai + 2014-11-14 23:44:23 -0800 + Commit: cbddac2, github.com/apache/spark/pull/3273 + + [SPARK-4260] Httpbroadcast should set connection timeout. + Kousuke Saruta + 2014-11-14 22:36:56 -0800 + Commit: 60969b0, github.com/apache/spark/pull/3122 + + [SPARK-4363][Doc] Update the Broadcast example + zsxwing + 2014-11-14 22:28:48 -0800 + Commit: 861223e, github.com/apache/spark/pull/3226 + + [SPARK-4379][Core] Change Exception to SparkException in checkpoint + zsxwing + 2014-11-14 22:25:41 -0800 + Commit: dba1405, github.com/apache/spark/pull/3241 + + [SPARK-4415] [PySpark] JVM should exit after Python exit + Davies Liu + 2014-11-14 20:13:46 -0800 + Commit: 7fe08b4, github.com/apache/spark/pull/3274 + + [SPARK-4404]SparkSubmitDriverBootstrapper should stop after its SparkSubmit sub-proc... + WangTao , WangTaoTheTonic + 2014-11-14 20:11:51 -0800 + Commit: 303a4e4, github.com/apache/spark/pull/3266 + + SPARK-4214. With dynamic allocation, avoid outstanding requests for more... + Sandy Ryza + 2014-11-14 15:51:05 -0800 + Commit: ad42b28, github.com/apache/spark/pull/3204 + + [SPARK-4412][SQL] Fix Spark's control of Parquet logging. + Jim Carroll + 2014-11-14 15:33:21 -0800 + Commit: 37482ce, github.com/apache/spark/pull/3271 + + [SPARK-4365][SQL] Remove unnecessary filter call on records returned from parquet library + Yash Datta + 2014-11-14 15:16:36 -0800 + Commit: 63ca3af, github.com/apache/spark/pull/3229 + + [SPARK-4386] Improve performance when writing Parquet files. + Jim Carroll + 2014-11-14 15:11:53 -0800 + Commit: f76b968, github.com/apache/spark/pull/3254 + + [SPARK-4322][SQL] Enables struct fields as sub expressions of grouping fields + Cheng Lian + 2014-11-14 15:09:36 -0800 + Commit: 0c7b66b, github.com/apache/spark/pull/3248 + + [SQL] Don't shuffle code generated rows + Michael Armbrust + 2014-11-14 15:03:23 -0800 + Commit: 4b4b50c, github.com/apache/spark/pull/3263 + + [SQL] Minor cleanup of comments, errors and override. + Michael Armbrust + 2014-11-14 15:00:42 -0800 + Commit: f805025, github.com/apache/spark/pull/3257 + + [SPARK-4391][SQL] Configure parquet filters using SQLConf + Michael Armbrust + 2014-11-14 14:59:35 -0800 + Commit: e47c387, github.com/apache/spark/pull/3258 + + [SPARK-4390][SQL] Handle NaN cast to decimal correctly + Michael Armbrust + 2014-11-14 14:56:57 -0800 + Commit: a0300ea, github.com/apache/spark/pull/3256 + + [SPARK-4062][Streaming]Add ReliableKafkaReceiver in Spark Streaming Kafka connector + jerryshao , Tathagata Das , Saisai Shao + 2014-11-14 14:33:37 -0800 + Commit: 5930f64, github.com/apache/spark/pull/2991 + + [SPARK-4333][SQL] Correctly log number of iterations in RuleExecutor + DoingDone9 <799203320@qq.com> + 2014-11-14 14:28:06 -0800 + Commit: 0cbdb01, github.com/apache/spark/pull/3180 + + SPARK-4375. no longer require -Pscala-2.10 + Sandy Ryza + 2014-11-14 14:21:57 -0800 + Commit: f5f757e, github.com/apache/spark/pull/3239 + + [SPARK-4245][SQL] Fix containsNull of the result ArrayType of CreateArray expression. + Takuya UESHIN + 2014-11-14 14:21:16 -0800 + Commit: bbd8f5b, github.com/apache/spark/pull/3110 + + [SPARK-4239] [SQL] support view in HiveQl + Daoyuan Wang + 2014-11-14 13:51:20 -0800 + Commit: ade72c4, github.com/apache/spark/pull/3131 + + Update failed assert text to match code in SizeEstimatorSuite + Jeff Hammerbacher + 2014-11-14 13:37:48 -0800 + Commit: c258db9, github.com/apache/spark/pull/3242 + + [SPARK-4313][WebUI][Yarn] Fix link issue of the executor thread dump page in yarn-cluster mode + zsxwing + 2014-11-14 13:36:13 -0800 + Commit: 156cf33, github.com/apache/spark/pull/3183 + + SPARK-3663 Document SPARK_LOG_DIR and SPARK_PID_DIR + Andrew Ash + 2014-11-14 13:33:35 -0800 + Commit: 5c265cc, github.com/apache/spark/pull/2518 + + [Spark Core] SPARK-4380 Edit spilling log from MB to B + Hong Shen + 2014-11-14 13:29:41 -0800 + Commit: 0c56a03, github.com/apache/spark/pull/3243 + + [SPARK-4398][PySpark] specialize sc.parallelize(xrange) + Xiangrui Meng + 2014-11-14 12:43:17 -0800 + Commit: abd5817, github.com/apache/spark/pull/3264 + + [SPARK-4394][SQL] Data Sources API Improvements + Michael Armbrust + 2014-11-14 12:00:08 -0800 + Commit: 77e845c, github.com/apache/spark/pull/3260 + + [SPARK-3722][Docs]minor improvement and fix in docs + WangTao + 2014-11-14 08:09:42 -0600 + Commit: e421072, github.com/apache/spark/pull/2579 + + [SPARK-4310][WebUI] Sort 'Submitted' column in Stage page by time + zsxwing + 2014-11-13 14:37:04 -0800 + Commit: 825709a, github.com/apache/spark/pull/3179 + + [SPARK-4372][MLLIB] Make LR and SVM's default parameters consistent in Scala and Python + Xiangrui Meng + 2014-11-13 13:54:16 -0800 + Commit: 3221830, github.com/apache/spark/pull/3232 + + [SPARK-4326] fix unidoc + Xiangrui Meng + 2014-11-13 13:16:20 -0800 + Commit: 4b0c1ed, github.com/apache/spark/pull/3253 + + [HOT FIX] make-distribution.sh fails if Yarn shuffle jar DNE + Andrew Or + 2014-11-13 11:54:45 -0800 + Commit: a0fa1ba, github.com/apache/spark/pull/3250 + + [SPARK-4378][MLLIB] make ALS more Java-friendly + Xiangrui Meng + 2014-11-13 11:42:27 -0800 + Commit: ca26a21, github.com/apache/spark/pull/3240 + + [SPARK-4348] [PySpark] [MLlib] rename random.py to rand.py + Davies Liu + 2014-11-13 10:24:54 -0800 + Commit: ce0333f, github.com/apache/spark/pull/3216 + + [SPARK-4256] Make Binary Evaluation Metrics functions defined in cases where there ar... + Andrew Bullen + 2014-11-12 22:14:44 -0800 + Commit: 484fecb, github.com/apache/spark/pull/3118 + + [SPARK-4370] [Core] Limit number of Netty cores based on executor size + Aaron Davidson + 2014-11-12 18:46:37 -0800 + Commit: b9e1c2e, github.com/apache/spark/pull/3155 + + [SPARK-4373][MLLIB] fix MLlib maven tests + Xiangrui Meng + 2014-11-12 18:15:14 -0800 + Commit: 23f5bdf, github.com/apache/spark/pull/3235 + + [Release] Bring audit scripts up-to-date + Andrew Or + 2014-11-13 00:30:58 +0000 + Commit: 723a86b + + [SPARK-2672] support compressed file in wholeTextFile + Davies Liu + 2014-11-12 15:58:12 -0800 + Commit: d7d54a4, github.com/apache/spark/pull/3005 + + [SPARK-4369] [MLLib] fix TreeModel.predict() with RDD + Davies Liu + 2014-11-12 13:56:41 -0800 + Commit: bd86118, github.com/apache/spark/pull/3230 + + [SPARK-3666] Extract interfaces for EdgeRDD and VertexRDD + Ankur Dave + 2014-11-12 13:49:20 -0800 + Commit: a5ef581, github.com/apache/spark/pull/2530 + + [Release] Correct make-distribution.sh log path + Andrew Or + 2014-11-12 13:46:26 -0800 + Commit: c3afd32 + + Internal cleanup for aggregateMessages + Ankur Dave + 2014-11-12 13:44:49 -0800 + Commit: 0402be9, github.com/apache/spark/pull/3231 + + [SPARK-4281][Build] Package Yarn shuffle service into its own jar + Andrew Or + 2014-11-12 13:39:45 -0800 + Commit: aa43a8d, github.com/apache/spark/pull/3147 + + [Test] Better exception message from SparkSubmitSuite + Andrew Or + 2014-11-12 13:35:48 -0800 + Commit: 6e3c5a2, github.com/apache/spark/pull/3212 + + [SPARK-3660][STREAMING] Initial RDD for updateStateByKey transformation + Soumitra Kumar + 2014-11-12 12:25:31 -0800 + Commit: 36ddeb7, github.com/apache/spark/pull/2665 + + [SPARK-3530][MLLIB] pipeline and parameters with examples + Xiangrui Meng + 2014-11-12 10:38:57 -0800 + Commit: 4b736db, github.com/apache/spark/pull/3099 + + [SPARK-4355][MLLIB] fix OnlineSummarizer.merge when other.mean is zero + Xiangrui Meng + 2014-11-12 01:50:11 -0800 + Commit: 84324fb, github.com/apache/spark/pull/3220 + + [SPARK-3936] Add aggregateMessages, which supersedes mapReduceTriplets + Ankur Dave + 2014-11-11 23:38:27 -0800 + Commit: faeb41d, github.com/apache/spark/pull/3100 + + [MLLIB] SPARK-4347: Reducing GradientBoostingSuite run time. + Manish Amde + 2014-11-11 22:47:53 -0800 + Commit: 2ef016b, github.com/apache/spark/pull/3214 + + Support cross building for Scala 2.11 + Prashant Sharma , Patrick Wendell + 2014-11-11 21:36:48 -0800 + Commit: daaca14, github.com/apache/spark/pull/3159 + + [Release] Log build output for each distribution + Andrew Or + 2014-11-11 18:02:59 -0800 + Commit: 2ddb141 + + SPARK-2269 Refactor mesos scheduler resourceOffers and add unit test + Timothy Chen + 2014-11-11 14:29:18 -0800 + Commit: a878660, github.com/apache/spark/pull/1487 + + [SPARK-4282][YARN] Stopping flag in YarnClientSchedulerBackend should be volatile + Kousuke Saruta + 2014-11-11 12:33:53 -0600 + Commit: 7f37188, github.com/apache/spark/pull/3143 + + SPARK-4305 [BUILD] yarn-alpha profile won't build due to network/yarn module + Sean Owen + 2014-11-11 12:30:35 -0600 + Commit: f820b56, github.com/apache/spark/pull/3167 + + SPARK-1830 Deploy failover, Make Persistence engine and LeaderAgent Pluggable + Prashant Sharma + 2014-11-11 09:29:48 -0800 + Commit: deefd9d, github.com/apache/spark/pull/771 + + [Streaming][Minor]Replace some 'if-else' in Clock + huangzhaowei + 2014-11-11 03:02:12 -0800 + Commit: 6e03de3, github.com/apache/spark/pull/3088 + + [SPARK-2492][Streaming] kafkaReceiver minor changes to align with Kafka 0.8 + jerryshao + 2014-11-11 02:22:23 -0800 + Commit: c8850a3, github.com/apache/spark/pull/1420 + + [SPARK-4295][External]Fix exception in SparkSinkSuite + maji2014 + 2014-11-11 02:18:27 -0800 + Commit: f8811a5, github.com/apache/spark/pull/3177 + + [SPARK-4307] Initialize FileDescriptor lazily in FileRegion. + Reynold Xin , Reynold Xin + 2014-11-11 00:25:31 -0800 + Commit: ef29a9a, github.com/apache/spark/pull/3172 + + [SPARK-4324] [PySpark] [MLlib] support numpy.array for all MLlib API + Davies Liu + 2014-11-10 22:26:16 -0800 + Commit: 65083e9, github.com/apache/spark/pull/3189 + + [SPARK-4330][Doc] Link to proper URL for YARN overview + Kousuke Saruta + 2014-11-10 22:18:00 -0800 + Commit: 3c07b8f, github.com/apache/spark/pull/3196 + + [SPARK-3649] Remove GraphX custom serializers + Ankur Dave + 2014-11-10 19:31:52 -0800 + Commit: 300887b, github.com/apache/spark/pull/2503 + + [SPARK-4274] [SQL] Fix NPE in printing the details of the query plan + Cheng Hao + 2014-11-10 17:46:05 -0800 + Commit: c764d0a, github.com/apache/spark/pull/3139 + + [SPARK-3954][Streaming] Optimization to FileInputDStream + surq + 2014-11-10 17:37:16 -0800 + Commit: ce6ed2a, github.com/apache/spark/pull/2811 + + [SPARK-4149][SQL] ISO 8601 support for json date time strings + Daoyuan Wang + 2014-11-10 17:26:03 -0800 + Commit: a1fc059, github.com/apache/spark/pull/3012 + + [SPARK-4250] [SQL] Fix bug of constant null value mapping to ConstantObjectInspector + Cheng Hao + 2014-11-10 17:22:57 -0800 + Commit: fa77783, github.com/apache/spark/pull/3114 + + [SQL] remove a decimal case branch that has no effect at runtime + Xiangrui Meng + 2014-11-10 17:20:52 -0800 + Commit: d793d80, github.com/apache/spark/pull/3192 + + [SPARK-4308][SQL] Sets SQL operation state to ERROR when exception is thrown + Cheng Lian + 2014-11-10 16:56:36 -0800 + Commit: acb55ae, github.com/apache/spark/pull/3175 + + [SPARK-4000][Build] Uploads HiveCompatibilitySuite logs + Cheng Lian + 2014-11-10 16:17:52 -0800 + Commit: 534b231, github.com/apache/spark/pull/2993 + + [SPARK-4319][SQL] Enable an ignored test "null count". + Takuya UESHIN + 2014-11-10 15:55:15 -0800 + Commit: dbf1058, github.com/apache/spark/pull/3185 + + Revert "[SPARK-2703][Core]Make Tachyon related unit tests execute without deploying a Tachyon system locally." + Patrick Wendell + 2014-11-10 14:56:06 -0800 + Commit: 6e7a309 + + [SPARK-4047] - Generate runtime warnings for example implementation of PageRank + Varadharajan Mukundan + 2014-11-10 14:32:29 -0800 + Commit: 974d334, github.com/apache/spark/pull/2894 + + SPARK-1297 Upgrade HBase dependency to 0.98 + tedyu + 2014-11-10 13:23:33 -0800 + Commit: b32734e, github.com/apache/spark/pull/3115 + + SPARK-4230. Doc for spark.default.parallelism is incorrect + Sandy Ryza + 2014-11-10 12:40:41 -0800 + Commit: c6f4e70, github.com/apache/spark/pull/3107 + + [SPARK-4312] bash doesn't have "die" + Jey Kottalam + 2014-11-10 12:37:56 -0800 + Commit: c5db8e2, github.com/apache/spark/pull/2898 + + Update RecoverableNetworkWordCount.scala + comcmipi + 2014-11-10 12:33:48 -0800 + Commit: 0340c56, github.com/apache/spark/pull/2735 + + SPARK-2548 [STREAMING] JavaRecoverableWordCount is missing + Sean Owen + 2014-11-10 11:47:27 -0800 + Commit: 3a02d41, github.com/apache/spark/pull/2564 + + [SPARK-4169] [Core] Accommodate non-English Locales in unit tests + Niklas Wilcke <1wilcke@informatik.uni-hamburg.de> + 2014-11-10 11:37:38 -0800 + Commit: ed8bf1e, github.com/apache/spark/pull/3036 + + [SQL] support udt to hive types conversion (hive->udt is not supported) + Xiangrui Meng + 2014-11-10 11:04:12 -0800 + Commit: 894a724, github.com/apache/spark/pull/3164 + + [SPARK-2703][Core]Make Tachyon related unit tests execute without deploying a Tachyon system locally. + RongGu + 2014-11-09 23:48:15 -0800 + Commit: bd86cb1, github.com/apache/spark/pull/3030 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-11-09 23:07:14 -0800 + Commit: 227488d, github.com/apache/spark/pull/2898 + + SPARK-3179. Add task OutputMetrics. + Sandy Ryza + 2014-11-09 22:29:03 -0800 + Commit: 3c2cff4, github.com/apache/spark/pull/2968 + + SPARK-1209 [CORE] (Take 2) SparkHadoop{MapRed,MapReduce}Util should not use package org.apache.hadoop + Sean Owen + 2014-11-09 22:11:20 -0800 + Commit: f8e5732, github.com/apache/spark/pull/3048 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-11-09 18:16:20 -0800 + Commit: f73b56f, github.com/apache/spark/pull/464 + + SPARK-1344 [DOCS] Scala API docs for top methods + Sean Owen + 2014-11-09 17:42:08 -0800 + Commit: d136265, github.com/apache/spark/pull/3168 + + SPARK-971 [DOCS] Link to Confluence wiki from project website / documentation + Sean Owen + 2014-11-09 17:40:48 -0800 + Commit: 8c99a47, github.com/apache/spark/pull/3169 + + [SPARK-4301] StreamingContext should not allow start() to be called after calling stop() + Josh Rosen + 2014-11-08 18:10:23 -0800 + Commit: 7b41b17, github.com/apache/spark/pull/3160 + + [Minor] [Core] Don't NPE on closeQuietly(null) + Aaron Davidson + 2014-11-08 13:03:51 -0800 + Commit: 4af5c7e, github.com/apache/spark/pull/3166 + + [SPARK-4291][Build] Rename network module projects + Andrew Or + 2014-11-07 23:16:13 -0800 + Commit: 7afc856, github.com/apache/spark/pull/3148 + + [MLLIB] [PYTHON] SPARK-4221: Expose nonnegative ALS in the python API + Michelangelo D'Agostino + 2014-11-07 22:53:01 -0800 + Commit: 7e9d975, github.com/apache/spark/pull/3095 + + [SPARK-4304] [PySpark] Fix sort on empty RDD + Davies Liu + 2014-11-07 20:53:03 -0800 + Commit: 7779109, github.com/apache/spark/pull/3162 + + MAINTENANCE: Automated closing of pull requests. + Patrick Wendell + 2014-11-07 13:08:25 -0800 + Commit: 5923dd9, github.com/apache/spark/pull/3016 + + Update JavaCustomReceiver.java + xiao321 <1042460381@qq.com> + 2014-11-07 12:56:49 -0800 + Commit: 7c9ec52, github.com/apache/spark/pull/3153 + + [SPARK-4292][SQL] Result set iterator bug in JDBC/ODBC + wangfei + 2014-11-07 12:55:11 -0800 + Commit: d6e5552, github.com/apache/spark/pull/3149 + + [SPARK-4203][SQL] Partition directories in random order when inserting into hive table + Matthew Taylor + 2014-11-07 12:53:08 -0800 + Commit: ac70c97, github.com/apache/spark/pull/3076 + + [SPARK-4270][SQL] Fix Cast from DateType to DecimalType. + Takuya UESHIN + 2014-11-07 12:30:47 -0800 + Commit: a6405c5, github.com/apache/spark/pull/3134 + + [SPARK-4272] [SQL] Add more unwrapper functions for primitive type in TableReader + Cheng Hao + 2014-11-07 12:15:53 -0800 + Commit: 60ab80f, github.com/apache/spark/pull/3136 + + [SPARK-4213][SQL] ParquetFilters - No support for LT, LTE, GT, GTE operators + Kousuke Saruta + 2014-11-07 11:56:40 -0800 + Commit: 14c54f1, github.com/apache/spark/pull/3083 + + [SQL] Modify keyword val location according to ordering + Jacky Li + 2014-11-07 11:52:08 -0800 + Commit: 68609c5, github.com/apache/spark/pull/3080 + + [SQL] Support ScalaReflection of schema in different universes + Michael Armbrust + 2014-11-07 11:51:20 -0800 + Commit: 8154ed7, github.com/apache/spark/pull/3096 + + [SPARK-4225][SQL] Resorts to SparkContext.version to inspect Spark version + Cheng Lian + 2014-11-07 11:45:25 -0800 + Commit: 86e9eaa, github.com/apache/spark/pull/3105 + + [SQL][DOC][Minor] Spark SQL Hive now support dynamic partitioning + wangfei + 2014-11-07 11:43:35 -0800 + Commit: 636d7bc, github.com/apache/spark/pull/3127 + + [SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages + Aaron Davidson + 2014-11-07 09:42:21 -0800 + Commit: d4fa04e, github.com/apache/spark/pull/3146 + + [SPARK-4204][Core][WebUI] Change Utils.exceptionString to contain the inner exceptions and make the error information in Web UI more friendly + zsxwing + 2014-11-06 21:52:12 -0800 + Commit: 3abdb1b, github.com/apache/spark/pull/3073 + + [SPARK-4236] Cleanup removed applications' files in shuffle service + Aaron Davidson + 2014-11-06 19:54:32 -0800 + Commit: 48a19a6, github.com/apache/spark/pull/3126 + + [SPARK-4188] [Core] Perform network-level retry of shuffle file fetches + Aaron Davidson + 2014-11-06 18:39:14 -0800 + Commit: f165b2b, github.com/apache/spark/pull/3101 + + [SPARK-4277] Support external shuffle service on Standalone Worker + Aaron Davidson + 2014-11-06 17:20:46 -0800 + Commit: 6e9ef10, github.com/apache/spark/pull/3142 + + [SPARK-3797] Minor addendum to Yarn shuffle service + Andrew Or + 2014-11-06 17:18:49 -0800 + Commit: 96136f2, github.com/apache/spark/pull/3144 + + [HOT FIX] Make distribution fails + Andrew Or + 2014-11-06 15:31:07 -0800 + Commit: 470881b, github.com/apache/spark/pull/3145 + + [SPARK-4249][GraphX]fix a problem of EdgePartitionBuilder in Graphx + lianhuiwang + 2014-11-06 10:46:45 -0800 + Commit: d15c6e9, github.com/apache/spark/pull/3138 + + [SPARK-4264] Completion iterator should only invoke callback once + Aaron Davidson + 2014-11-06 10:45:46 -0800 + Commit: 23eaf0e, github.com/apache/spark/pull/3128 + + [SPARK-4186] add binaryFiles and binaryRecords in Python + Davies Liu + 2014-11-06 00:22:19 -0800 + Commit: b41a39e, github.com/apache/spark/pull/3078 + + [SPARK-4255] Fix incorrect table striping + Kay Ousterhout + 2014-11-06 00:03:03 -0800 + Commit: 5f27ae1, github.com/apache/spark/pull/3117 + + [SPARK-4137] [EC2] Don't change working dir on user + Nicholas Chammas + 2014-11-05 20:45:35 -0800 + Commit: db45f5a, github.com/apache/spark/pull/2988 + + [SPARK-4262][SQL] add .schemaRDD to JavaSchemaRDD + Xiangrui Meng + 2014-11-05 19:56:16 -0800 + Commit: 3d2b5bc, github.com/apache/spark/pull/3125 + + [SPARK-4254] [mllib] MovieLensALS bug fix + Joseph K. Bradley + 2014-11-05 19:51:18 -0800 + Commit: c315d13, github.com/apache/spark/pull/3116 + + [SPARK-4158] Fix for missing resources. + Brenden Matthews + 2014-11-05 16:02:44 -0800 + Commit: cb0eae3, github.com/apache/spark/pull/3024 + + SPARK-3223 runAsSparkUser cannot change HDFS write permission properly i... + Jongyoul Lee + 2014-11-05 15:49:42 -0800 + Commit: f7ac8c2, github.com/apache/spark/pull/3034 + + SPARK-4040. Update documentation to exemplify use of local (n) value, fo... + jay@apache.org + 2014-11-05 15:45:34 -0800 + Commit: 868cd4c, github.com/apache/spark/pull/2964 + + [SPARK-3797] Run external shuffle service in Yarn NM + Andrew Or + 2014-11-05 15:42:05 -0800 + Commit: 61a5cce, github.com/apache/spark/pull/3082 + + SPARK-4222 [CORE] use readFully in FixedLengthBinaryRecordReader + industrial-sloth + 2014-11-05 15:38:48 -0800 + Commit: f37817b, github.com/apache/spark/pull/3093 + + [SPARK-3984] [SPARK-3983] Fix incorrect scheduler delay and display task deserialization time in UI + Kay Ousterhout + 2014-11-05 15:30:31 -0800 + Commit: a46497e, github.com/apache/spark/pull/2832 + + [SPARK-4242] [Core] Add SASL to external shuffle service + Aaron Davidson + 2014-11-05 14:38:43 -0800 + Commit: 4c42986, github.com/apache/spark/pull/3108 + + [SPARK-4197] [mllib] GradientBoosting API cleanup and examples in Scala, Java + Joseph K. Bradley + 2014-11-05 10:33:13 -0800 + Commit: 5b3b6f6, github.com/apache/spark/pull/3094 + + [SPARK-4029][Streaming] Update streaming driver to reliably save and recover received block metadata on driver failures + Tathagata Das + 2014-11-05 01:21:53 -0800 + Commit: 5f13759, github.com/apache/spark/pull/3026 + + [SPARK-3964] [MLlib] [PySpark] add Hypothesis test Python API + Davies Liu + 2014-11-04 21:35:52 -0800 + Commit: c8abddc, github.com/apache/spark/pull/3091 + + [SQL] Add String option for DSL AS + Michael Armbrust + 2014-11-04 18:14:28 -0800 + Commit: 515abb9, github.com/apache/spark/pull/3097 + + [SPARK-2938] Support SASL authentication in NettyBlockTransferService + Aaron Davidson + 2014-11-04 16:15:38 -0800 + Commit: 5e73138, github.com/apache/spark/pull/3087 + + [Spark-4060] [MLlib] exposing special rdd functions to the public + Niklas Wilcke <1wilcke@informatik.uni-hamburg.de> + 2014-11-04 09:57:03 -0800 + Commit: f90ad5d, github.com/apache/spark/pull/2907 + + fixed MLlib Naive-Bayes java example bug + Dariusz Kobylarz + 2014-11-04 09:53:43 -0800 + Commit: bcecd73, github.com/apache/spark/pull/3081 + + [SPARK-3886] [PySpark] simplify serializer, use AutoBatchedSerializer by default. + Davies Liu + 2014-11-03 23:56:14 -0800 + Commit: e4f4263, github.com/apache/spark/pull/2920 + + [SPARK-4166][Core] Add a backward compatibility test for ExecutorLostFailure + zsxwing + 2014-11-03 22:47:45 -0800 + Commit: b671ce0, github.com/apache/spark/pull/3085 + + [SPARK-4163][Core] Add a backward compatibility test for FetchFailed + zsxwing + 2014-11-03 22:40:43 -0800 + Commit: 9bdc841, github.com/apache/spark/pull/3086 + + [SPARK-3573][MLLIB] Make MLlib's Vector compatible with SQL's SchemaRDD + Xiangrui Meng + 2014-11-03 22:29:48 -0800 + Commit: 1a9c6cd, github.com/apache/spark/pull/3070 + + [SPARK-4192][SQL] Internal API for Python UDT + Xiangrui Meng + 2014-11-03 19:29:11 -0800 + Commit: 04450d1, github.com/apache/spark/pull/3068 + + [FIX][MLLIB] fix seed in BaggedPointSuite + Xiangrui Meng + 2014-11-03 18:50:37 -0800 + Commit: c5912ec, github.com/apache/spark/pull/3084 + + [SPARK-611] Display executor thread dumps in web UI + Josh Rosen + 2014-11-03 18:18:47 -0800 + Commit: 4f035dd, github.com/apache/spark/pull/2944 + + [SPARK-4168][WebUI] web statges number should show correctly when stages are more than 1000 + Zhang, Liye + 2014-11-03 18:17:32 -0800 + Commit: 97a466e, github.com/apache/spark/pull/3035 + + [SQL] Convert arguments to Scala UDFs + Michael Armbrust + 2014-11-03 18:04:51 -0800 + Commit: 15b58a2, github.com/apache/spark/pull/3077 + + SPARK-4178. Hadoop input metrics ignore bytes read in RecordReader insta... + Sandy Ryza + 2014-11-03 15:19:01 -0800 + Commit: 2812815, github.com/apache/spark/pull/3045 + + [SQL] More aggressive defaults + Michael Armbrust + 2014-11-03 14:08:27 -0800 + Commit: 25bef7e, github.com/apache/spark/pull/3064 + + [SPARK-4152] [SQL] Avoid data change in CTAS while table already existed + Cheng Hao + 2014-11-03 13:59:43 -0800 + Commit: e83f13e, github.com/apache/spark/pull/3013 + + [SPARK-4202][SQL] Simple DSL support for Scala UDF + Cheng Lian + 2014-11-03 13:20:33 -0800 + Commit: c238fb4, github.com/apache/spark/pull/3067 + + [SPARK-3594] [PySpark] [SQL] take more rows to infer schema or sampling + Davies Liu , Davies Liu + 2014-11-03 13:17:09 -0800 + Commit: 24544fb, github.com/apache/spark/pull/2716 + + [SPARK-4207][SQL] Query which has syntax like 'not like' is not working in Spark SQL + ravipesala + 2014-11-03 13:07:41 -0800 + Commit: 2b6e1ce, github.com/apache/spark/pull/3075 + + [SPARK-4211][Build] Fixes hive.version in Maven profile hive-0.13.1 + fi + 2014-11-03 12:56:56 -0800 + Commit: df607da, github.com/apache/spark/pull/3072 + + [SPARK-4148][PySpark] fix seed distribution and add some tests for rdd.sample + Xiangrui Meng + 2014-11-03 12:24:24 -0800 + Commit: 3cca196, github.com/apache/spark/pull/3010 + + [EC2] Factor out Mesos spark-ec2 branch + Nicholas Chammas + 2014-11-03 09:02:35 -0800 + Commit: 2aca97c, github.com/apache/spark/pull/3008 + diff --git a/LICENSE b/LICENSE index 0a42d389e4c3c..9b364a4d00079 100644 --- a/LICENSE +++ b/LICENSE @@ -771,6 +771,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. +======================================================================== +For TestTimSort (core/src/test/java/org/apache/spark/util/collection/TestTimSort.java): +======================================================================== +Copyright (C) 2015 Stijn de Gouw + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. ======================================================================== For LimitedInputStream diff --git a/assembly/pom.xml b/assembly/pom.xml index 1bb5a671f5390..0952cd2fba101 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../pom.xml @@ -39,7 +39,7 @@ spark /usr/share/spark root - 744 + 755 @@ -221,6 +221,24 @@ deb + + maven-antrun-plugin + + + prepare-package + + run + + + + + NOTE: Debian packaging is deprecated and is scheduled to be removed in Spark 1.4. + + + + + + org.codehaus.mojo buildnumber-maven-plugin @@ -280,7 +298,7 @@ ${deb.user} ${deb.user} ${deb.install.path}/conf - 744 + ${deb.bin.filemode} @@ -302,7 +320,7 @@ ${deb.user} ${deb.user} ${deb.install.path}/sbin - 744 + ${deb.bin.filemode} @@ -313,7 +331,7 @@ ${deb.user} ${deb.user} ${deb.install.path}/python - 744 + ${deb.bin.filemode} diff --git a/bagel/pom.xml b/bagel/pom.xml index 510e92640eff8..602cc7b822520 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../pom.xml diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index a8c344b1ca594..679dfaf3a33f9 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -25,17 +25,24 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" . "$FWDIR"/bin/load-spark-env.sh -if [ -n "$SPARK_CLASSPATH" ]; then - CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH" -else - CLASSPATH="$SPARK_SUBMIT_CLASSPATH" -fi +function appendToClasspath(){ + if [ -n "$1" ]; then + if [ -n "$CLASSPATH" ]; then + CLASSPATH="$CLASSPATH:$1" + else + CLASSPATH="$1" + fi + fi +} + +appendToClasspath "$SPARK_CLASSPATH" +appendToClasspath "$SPARK_SUBMIT_CLASSPATH" # Build up classpath if [ -n "$SPARK_CONF_DIR" ]; then - CLASSPATH="$CLASSPATH:$SPARK_CONF_DIR" + appendToClasspath "$SPARK_CONF_DIR" else - CLASSPATH="$CLASSPATH:$FWDIR/conf" + appendToClasspath "$FWDIR/conf" fi ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SPARK_SCALA_VERSION" @@ -51,20 +58,20 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ "classes ahead of assembly." >&2 # Spark classes - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes" + appendToClasspath "$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes" # Jars for shaded deps in their original form (copied here during build) - CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*" + appendToClasspath "$FWDIR/core/target/jars/*" fi # Use spark-assembly jar from either RELEASE or assembly directory @@ -76,7 +83,7 @@ fi num_jars=0 -for f in ${assembly_folder}/spark-assembly*hadoop*.jar; do +for f in "${assembly_folder}"/spark-assembly*hadoop*.jar; do if [[ ! -e "$f" ]]; then echo "Failed to find Spark assembly in $assembly_folder" 1>&2 echo "You need to build Spark before running this program." 1>&2 @@ -88,22 +95,25 @@ done if [ "$num_jars" -gt "1" ]; then echo "Found multiple Spark assembly jars in $assembly_folder:" 1>&2 - ls ${assembly_folder}/spark-assembly*hadoop*.jar 1>&2 + ls "${assembly_folder}"/spark-assembly*hadoop*.jar 1>&2 echo "Please remove all but one jar." 1>&2 exit 1 fi -# Verify that versions of java used to build the jars and run Spark are compatible -jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1) -if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then - echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2 - echo "This is likely because Spark was compiled with Java 7 and run " 1>&2 - echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2 - echo "or build Spark with Java 6." 1>&2 - exit 1 +# Only able to make this check if 'jar' command is available +if [ $(command -v "$JAR_CMD") ] ; then + # Verify that versions of java used to build the jars and run Spark are compatible + jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1) + if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then + echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2 + echo "This is likely because Spark was compiled with Java 7 and run " 1>&2 + echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2 + echo "or build Spark with Java 6." 1>&2 + exit 1 + fi fi -CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR" +appendToClasspath "$ASSEMBLY_JAR" # When Hive support is needed, Datanucleus jars must be included on the classpath. # Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. @@ -121,41 +131,31 @@ datanucleus_jars="$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\ datanucleus_jars="$(echo "$datanucleus_jars" | tr "\n" : | sed s/:$//g)" if [ -n "$datanucleus_jars" ]; then - hive_files=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" org/apache/hadoop/hive/ql/exec 2>/dev/null) - if [ -n "$hive_files" ]; then - echo "Spark assembly has been built with Hive, including Datanucleus jars on classpath" 1>&2 - CLASSPATH="$CLASSPATH:$datanucleus_jars" - fi + appendToClasspath "$datanucleus_jars" fi # Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1 if [[ $SPARK_TESTING == 1 ]]; then - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes" + appendToClasspath "$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/test-classes" + appendToClasspath "$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/test-classes" + appendToClasspath "$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/test-classes" + appendToClasspath "$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/test-classes" + appendToClasspath "$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/test-classes" + appendToClasspath "$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/test-classes" + appendToClasspath "$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes" + appendToClasspath "$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes" + appendToClasspath "$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes" fi # Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail ! # Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts # the configurtion files. -if [ -n "$HADOOP_CONF_DIR" ]; then - CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR" -fi -if [ -n "$YARN_CONF_DIR" ]; then - CLASSPATH="$CLASSPATH:$YARN_CONF_DIR" -fi +appendToClasspath "$HADOOP_CONF_DIR" +appendToClasspath "$YARN_CONF_DIR" # To allow for distributions to append needed libraries to the classpath (e.g. when # using the "hadoop-provided" profile to build Spark), check SPARK_DIST_CLASSPATH and # append it to tbe final classpath. -if [ -n "$SPARK_DIST_CLASSPATH" ]; then - CLASSPATH="$CLASSPATH:$SPARK_DIST_CLASSPATH" -fi +appendToClasspath "$SPARK_DIST_CLASSPATH" echo "$CLASSPATH" diff --git a/bin/run-example b/bin/run-example index c567acf9a6b5c..a106411392e06 100755 --- a/bin/run-example +++ b/bin/run-example @@ -42,7 +42,7 @@ fi JAR_COUNT=0 -for f in ${JAR_PATH}/spark-examples-*hadoop*.jar; do +for f in "${JAR_PATH}"/spark-examples-*hadoop*.jar; do if [[ ! -e "$f" ]]; then echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 @@ -54,7 +54,7 @@ done if [ "$JAR_COUNT" -gt "1" ]; then echo "Found multiple Spark examples assembly jars in ${JAR_PATH}" 1>&2 - ls ${JAR_PATH}/spark-examples-*hadoop*.jar 1>&2 + ls "${JAR_PATH}"/spark-examples-*hadoop*.jar 1>&2 echo "Please remove all but one jar." 1>&2 exit 1 fi diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd old mode 100755 new mode 100644 diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index 12244a9cb04fb..446cbc74b74f9 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -25,7 +25,7 @@ set ORIG_ARGS=%* rem Reset the values of all variables used set SPARK_SUBMIT_DEPLOY_MODE=client -if not defined %SPARK_CONF_DIR% ( +if [%SPARK_CONF_DIR%] == [] ( set SPARK_CONF_DIR=%SPARK_HOME%\conf ) set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_CONF_DIR%\spark-defaults.conf diff --git a/bin/utils.sh b/bin/utils.sh index 22ea2b9a6d586..748dbe345a74c 100755 --- a/bin/utils.sh +++ b/bin/utils.sh @@ -26,16 +26,17 @@ function gatherSparkSubmitOpts() { exit 1 fi - # NOTE: If you add or remove spark-sumbmit options, + # NOTE: If you add or remove spark-submit options, # modify NOT ONLY this script but also SparkSubmitArgument.scala SUBMISSION_OPTS=() APPLICATION_OPTS=() while (($#)); do case "$1" in - --master | --deploy-mode | --class | --name | --jars | --py-files | --files | \ - --conf | --properties-file | --driver-memory | --driver-java-options | \ + --master | --deploy-mode | --class | --name | --jars | --packages | --py-files | --files | \ + --conf | --repositories | --properties-file | --driver-memory | --driver-java-options | \ --driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \ - --total-executor-cores | --executor-cores | --queue | --num-executors | --archives) + --total-executor-cores | --executor-cores | --queue | --num-executors | --archives | \ + --proxy-user) if [[ $# -lt 2 ]]; then "$SUBMIT_USAGE_FUNCTION" exit 1; diff --git a/bin/windows-utils.cmd b/bin/windows-utils.cmd index 1082a952dac99..0cf9e87ca554b 100644 --- a/bin/windows-utils.cmd +++ b/bin/windows-utils.cmd @@ -32,7 +32,8 @@ SET opts="\<--master\> \<--deploy-mode\> \<--class\> \<--name\> \<--jars\> \<--p SET opts="%opts:~1,-1% \<--conf\> \<--properties-file\> \<--driver-memory\> \<--driver-java-options\>" SET opts="%opts:~1,-1% \<--driver-library-path\> \<--driver-class-path\> \<--executor-memory\>" SET opts="%opts:~1,-1% \<--driver-cores\> \<--total-executor-cores\> \<--executor-cores\> \<--queue\>" -SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\>" +SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\> \<--packages\> \<--repositories\>" +SET opts="%opts:~1,-1% \<--proxy-user\>" echo %1 | findstr %opts% >nul if %ERRORLEVEL% equ 0 ( diff --git a/build/mvn b/build/mvn index a87c5a26230c8..3561110a4c019 100755 --- a/build/mvn +++ b/build/mvn @@ -21,6 +21,8 @@ _DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" # Preserve the calling directory _CALLING_DIR="$(pwd)" +# Options used during compilation +_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" # Installs any application tarball given a URL, the expected tarball name, # and, optionally, a checkable binary path to determine if the binary has @@ -34,14 +36,14 @@ install_app() { local binary="${_DIR}/$3" # setup `curl` and `wget` silent options if we're running on Jenkins - local curl_opts="" + local curl_opts="-L" local wget_opts="" if [ -n "$AMPLAB_JENKINS" ]; then - curl_opts="-s" - wget_opts="--quiet" + curl_opts="-s ${curl_opts}" + wget_opts="--quiet ${wget_opts}" else - curl_opts="--progress-bar" - wget_opts="--progress=bar:force" + curl_opts="--progress-bar ${curl_opts}" + wget_opts="--progress=bar:force ${wget_opts}" fi if [ -z "$3" -o ! -f "$binary" ]; then @@ -136,6 +138,7 @@ cd "${_CALLING_DIR}" # Now that zinc is ensured to be installed, check its status and, if its # not running or just installed, start it if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then + export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} ${ZINC_BIN} -shutdown ${ZINC_BIN} -start -port ${ZINC_PORT} \ -scala-compiler "${SCALA_COMPILER}" \ @@ -143,7 +146,7 @@ if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then fi # Set any `mvn` options if not already present -export MAVEN_OPTS=${MAVEN_OPTS:-"-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"} +export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} # Last, call the `mvn` command as usual ${MVN_BIN} "$@" diff --git a/build/sbt b/build/sbt index 28ebb64f7197c..cc3203d79bccd 100755 --- a/build/sbt +++ b/build/sbt @@ -125,4 +125,32 @@ loadConfigFile() { [[ -f "$etc_sbt_opts_file" ]] && set -- $(loadConfigFile "$etc_sbt_opts_file") "$@" [[ -f "$sbt_opts_file" ]] && set -- $(loadConfigFile "$sbt_opts_file") "$@" +exit_status=127 +saved_stty="" + +restoreSttySettings() { + stty $saved_stty + saved_stty="" +} + +onExit() { + if [[ "$saved_stty" != "" ]]; then + restoreSttySettings + fi + exit $exit_status +} + +saveSttySettings() { + saved_stty=$(stty -g 2>/dev/null) + if [[ ! $? ]]; then + saved_stty="" + fi +} + +saveSttySettings +trap onExit INT + run "$@" + +exit_status=$? +onExit diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index 5e0c640fa5919..504be48b358fa 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -81,7 +81,7 @@ execRunner () { echo "" } - exec "$@" + "$@" } addJava () { diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 0886b0276fb90..67f81d33361e1 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -15,7 +15,7 @@ # - SPARK_PUBLIC_DNS, to set the public DNS name of the driver program # - SPARK_CLASSPATH, default classpath entries to append # - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data -# - MESOS_NATIVE_LIBRARY, to point to your libmesos.so if you use Mesos +# - MESOS_NATIVE_JAVA_LIBRARY, to point to your libmesos.so if you use Mesos # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files diff --git a/core/pom.xml b/core/pom.xml index 2c115683fce66..5971d05c0db7c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../pom.xml @@ -132,6 +132,13 @@ jetty-servlet compile + + + org.eclipse.jetty.orbit + javax.servlet + ${orbit.version} + org.apache.commons @@ -236,11 +243,30 @@ io.dropwizard.metrics metrics-graphite + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.module + jackson-module-scala_2.10 + org.apache.derby derby test + + org.apache.ivy + ivy + ${ivy.version} + + + oro + + oro + ${oro.version} + org.tachyonproject tachyon-client @@ -293,6 +319,12 @@ selenium-java test + + + xml-apis + xml-apis + test + org.mockito mockito-all @@ -303,16 +335,6 @@ scalacheck_${scala.binary.version} test - - org.easymock - easymockclassextension - test - - - asm - asm - test - junit junit @@ -387,7 +409,7 @@ true true - guava,jetty-io,jetty-servlet,jetty-continuation,jetty-http,jetty-plus,jetty-util,jetty-server + guava,jetty-io,jetty-servlet,jetty-continuation,jetty-http,jetty-plus,jetty-util,jetty-server,jetty-security true diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index 409e1a41c5d49..a90cc0e761f62 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -425,15 +425,14 @@ private void pushRun(int runBase, int runLen) { private void mergeCollapse() { while (stackSize > 1) { int n = stackSize - 2; - if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) { + if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) + || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) { if (runLen[n - 1] < runLen[n + 1]) n--; - mergeAt(n); - } else if (runLen[n] <= runLen[n + 1]) { - mergeAt(n); - } else { + } else if (runLen[n] > runLen[n + 1]) { break; // Invariant is established } + mergeAt(n); } } diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index f23ba9dba167f..6c37cc8b98236 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -103,6 +103,12 @@ span.expand-details { float: right; } +span.rest-uri { + font-size: 10pt; + font-style: italic; + color: gray; +} + pre { font-size: 0.8em; } @@ -190,7 +196,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ -.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time, -.getting_result_time { +.scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, +.serialization_time, .getting_result_time { display: none; } diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index a0c0372b7f0ef..a96d754744a05 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -47,10 +47,15 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { val inputMetrics = blockResult.inputMetrics val existingMetrics = context.taskMetrics .getInputMetricsForReadMethod(inputMetrics.readMethod) - existingMetrics.addBytesRead(inputMetrics.bytesRead) - - new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]]) + existingMetrics.incBytesRead(inputMetrics.bytesRead) + val iter = blockResult.data.asInstanceOf[Iterator[T]] + new InterruptibleIterator[T](context, iter) { + override def next(): T = { + existingMetrics.incRecordsRead(1) + delegate.next() + } + } case None => // Acquire a lock for loading this partition // If another thread already holds the lock, wait for it to finish return its results diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index ede1e23f4fcc5..98e440102e30e 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -104,9 +104,19 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { cleaningThread.start() } - /** Stop the cleaner. */ + /** + * Stop the cleaning thread and wait until the thread has finished running its current task. + */ def stop() { stopped = true + // Interrupt the cleaning thread, but wait until the current task has finished before + // doing so. This guards against the race condition where a cleaning thread may + // potentially clean similarly named variables created by a different SparkContext, + // resulting in otherwise inexplicable block-not-found exceptions (SPARK-6132). + synchronized { + cleaningThread.interrupt() + } + cleaningThread.join() } /** Register a RDD for cleanup when it is garbage collected. */ @@ -135,19 +145,23 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT)) .map(_.asInstanceOf[CleanupTaskWeakReference]) - reference.map(_.task).foreach { task => - logDebug("Got cleaning task " + task) - referenceBuffer -= reference.get - task match { - case CleanRDD(rddId) => - doCleanupRDD(rddId, blocking = blockOnCleanupTasks) - case CleanShuffle(shuffleId) => - doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) - case CleanBroadcast(broadcastId) => - doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + // Synchronize here to avoid being interrupted on stop() + synchronized { + reference.map(_.task).foreach { task => + logDebug("Got cleaning task " + task) + referenceBuffer -= reference.get + task match { + case CleanRDD(rddId) => + doCleanupRDD(rddId, blocking = blockOnCleanupTasks) + case CleanShuffle(shuffleId) => + doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) + case CleanBroadcast(broadcastId) => + doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + } } } } catch { + case ie: InterruptedException if stopped => // ignore case e: Exception => logError("Error in cleaning thread", e) } } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index a46a81eabd965..443830f8d03b6 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -19,24 +19,32 @@ package org.apache.spark /** * A client that communicates with the cluster manager to request or kill executors. + * This is currently supported only in YARN mode. */ private[spark] trait ExecutorAllocationClient { + /** + * Express a preference to the cluster manager for a given total number of executors. + * This can result in canceling pending requests or filing additional requests. + * @return whether the request is acknowledged by the cluster manager. + */ + private[spark] def requestTotalExecutors(numExecutors: Int): Boolean + /** * Request an additional number of executors from the cluster manager. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def requestExecutors(numAdditionalExecutors: Int): Boolean /** * Request that the cluster manager kill the specified executors. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def killExecutors(executorIds: Seq[String]): Boolean /** * Request that the cluster manager kill the specified executor. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId)) } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 5d5288bb6e60d..9385f557c4614 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -17,9 +17,12 @@ package org.apache.spark +import java.util.concurrent.{Executors, TimeUnit} + import scala.collection.mutable import org.apache.spark.scheduler._ +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. @@ -76,15 +79,15 @@ private[spark] class ExecutorAllocationManager( private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", Integer.MAX_VALUE) - // How long there must be backlogged tasks for before an addition is triggered + // How long there must be backlogged tasks for before an addition is triggered (seconds) private val schedulerBacklogTimeout = conf.getLong( - "spark.dynamicAllocation.schedulerBacklogTimeout", 60) + "spark.dynamicAllocation.schedulerBacklogTimeout", 5) // Same as above, but used only after `schedulerBacklogTimeout` is exceeded private val sustainedSchedulerBacklogTimeout = conf.getLong( "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) - // How long an executor must be idle for before it is removed + // How long an executor must be idle for before it is removed (seconds) private val executorIdleTimeout = conf.getLong( "spark.dynamicAllocation.executorIdleTimeout", 600) @@ -123,11 +126,15 @@ private[spark] class ExecutorAllocationManager( private val intervalMillis: Long = 100 // Clock used to schedule when executors should be added and removed - private var clock: Clock = new RealClock + private var clock: Clock = new SystemClock() // Listener for Spark events that impact the allocation policy private val listener = new ExecutorAllocationListener + // Executor that handles the scheduling task. + private val executor = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("spark-dynamic-executor-allocation")) + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -172,47 +179,55 @@ private[spark] class ExecutorAllocationManager( } /** - * Register for scheduler callbacks to decide when to add and remove executors. + * Register for scheduler callbacks to decide when to add and remove executors, and start + * the scheduling task. */ def start(): Unit = { listenerBus.addListener(listener) - startPolling() + + val scheduleTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + } + executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } /** - * Start the main polling thread that keeps track of when to add and remove executors. + * Stop the allocation manager. */ - private def startPolling(): Unit = { - val t = new Thread { - override def run(): Unit = { - while (true) { - try { - schedule() - } catch { - case e: Exception => logError("Exception in dynamic executor allocation thread!", e) - } - Thread.sleep(intervalMillis) - } - } - } - t.setName("spark-dynamic-executor-allocation") - t.setDaemon(true) - t.start() + def stop(): Unit = { + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) } /** - * If the add time has expired, request new executors and refresh the add time. - * If the remove time for an existing executor has expired, kill the executor. + * The number of executors we would have if the cluster manager were to fulfill all our existing + * requests. + */ + private def targetNumExecutors(): Int = + numExecutorsPending + executorIds.size - executorsPendingToRemove.size + + /** + * The maximum number of executors we would need under the current load to satisfy all running + * and pending tasks, rounded up. + */ + private def maxNumExecutorsNeeded(): Int = { + val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks + (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor + } + + /** + * This is called at a fixed interval to regulate the number of pending executor requests + * and number of executors running. + * + * First, adjust our requested executors based on the add time and our current needs. + * Then, if the remove time for an existing executor has expired, kill the executor. + * * This is factored out into its own method for testing. */ private def schedule(): Unit = synchronized { val now = clock.getTimeMillis - if (addTime != NOT_SET && now >= addTime) { - addExecutors() - logDebug(s"Starting timer to add more executors (to " + - s"expire in $sustainedSchedulerBacklogTimeout seconds)") - addTime += sustainedSchedulerBacklogTimeout * 1000 - } + + addOrCancelExecutorRequests(now) removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime @@ -223,59 +238,89 @@ private[spark] class ExecutorAllocationManager( } } + /** + * Check to see whether our existing allocation and the requests we've made previously exceed our + * current needs. If so, let the cluster manager know so that it can cancel pending requests that + * are unneeded. + * + * If not, and the add time has expired, see if we can request new executors and refresh the add + * time. + * + * @return the delta in the target number of executors. + */ + private def addOrCancelExecutorRequests(now: Long): Int = synchronized { + val currentTarget = targetNumExecutors + val maxNeeded = maxNumExecutorsNeeded + + if (maxNeeded < currentTarget) { + // The target number exceeds the number we actually need, so stop adding new + // executors and inform the cluster manager to cancel the extra pending requests. + val newTotalExecutors = math.max(maxNeeded, minNumExecutors) + client.requestTotalExecutors(newTotalExecutors) + numExecutorsToAdd = 1 + updateNumExecutorsPending(newTotalExecutors) + } else if (addTime != NOT_SET && now >= addTime) { + val delta = addExecutors(maxNeeded) + logDebug(s"Starting timer to add more executors (to " + + s"expire in $sustainedSchedulerBacklogTimeout seconds)") + addTime += sustainedSchedulerBacklogTimeout * 1000 + delta + } else { + 0 + } + } + /** * Request a number of executors from the cluster manager. * If the cap on the number of executors is reached, give up and reset the * number of executors to add next round instead of continuing to double it. - * Return the number actually requested. + * + * @param maxNumExecutorsNeeded the maximum number of executors all currently running or pending + * tasks could fill + * @return the number of additional executors actually requested. */ - private def addExecutors(): Int = synchronized { - // Do not request more executors if we have already reached the upper bound - val numExistingExecutors = executorIds.size + numExecutorsPending - if (numExistingExecutors >= maxNumExecutors) { + private def addExecutors(maxNumExecutorsNeeded: Int): Int = { + // Do not request more executors if it would put our target over the upper bound + val currentTarget = targetNumExecutors + if (currentTarget >= maxNumExecutors) { logDebug(s"Not adding executors because there are already ${executorIds.size} " + s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)") numExecutorsToAdd = 1 return 0 } - // The number of executors needed to satisfy all pending tasks is the number of tasks pending - // divided by the number of tasks each executor can fit, rounded up. - val maxNumExecutorsPending = - (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor - if (numExecutorsPending >= maxNumExecutorsPending) { - logDebug(s"Not adding executors because there are already $numExecutorsPending " + - s"pending and pending tasks could only fill $maxNumExecutorsPending") - numExecutorsToAdd = 1 - return 0 - } - - // It's never useful to request more executors than could satisfy all the pending tasks, so - // cap request at that amount. - // Also cap request with respect to the configured upper bound. - val maxNumExecutorsToAdd = math.min( - maxNumExecutorsPending - numExecutorsPending, - maxNumExecutors - numExistingExecutors) - assert(maxNumExecutorsToAdd > 0) - - val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd) - - val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd - val addRequestAcknowledged = testing || client.requestExecutors(actualNumExecutorsToAdd) + val actualMaxNumExecutors = math.min(maxNumExecutors, maxNumExecutorsNeeded) + val newTotalExecutors = math.min(currentTarget + numExecutorsToAdd, actualMaxNumExecutors) + val addRequestAcknowledged = testing || client.requestTotalExecutors(newTotalExecutors) if (addRequestAcknowledged) { - logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " + - s"tasks are backlogged (new desired total will be $newTotalExecutors)") - numExecutorsToAdd = - if (actualNumExecutorsToAdd == numExecutorsToAdd) numExecutorsToAdd * 2 else 1 - numExecutorsPending += actualNumExecutorsToAdd - actualNumExecutorsToAdd + val delta = updateNumExecutorsPending(newTotalExecutors) + logInfo(s"Requesting $delta new executor(s) because tasks are backlogged" + + s" (new desired total will be $newTotalExecutors)") + numExecutorsToAdd = if (delta == numExecutorsToAdd) { + numExecutorsToAdd * 2 + } else { + 1 + } + delta } else { - logWarning(s"Unable to reach the cluster manager " + - s"to request $actualNumExecutorsToAdd executors!") + logWarning( + s"Unable to reach the cluster manager to request $newTotalExecutors total executors!") 0 } } + /** + * Given the new target number of executors, update the number of pending executor requests, + * and return the delta from the old number of pending requests. + */ + private def updateNumExecutorsPending(newTotalExecutors: Int): Int = { + val newNumExecutorsPending = + newTotalExecutors - executorIds.size + executorsPendingToRemove.size + val delta = newNumExecutorsPending - numExecutorsPending + numExecutorsPending = newNumExecutorsPending + delta + } + /** * Request the cluster manager to remove the given executor. * Return whether the request is received. @@ -415,6 +460,8 @@ private[spark] class ExecutorAllocationManager( private val stageIdToNumTasks = new mutable.HashMap[Int, Int] private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] + // Number of tasks currently running on the cluster. Should be 0 when no stages are active. + private var numRunningTasks: Int = _ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { val stageId = stageSubmitted.stageInfo.stageId @@ -435,6 +482,10 @@ private[spark] class ExecutorAllocationManager( // This is needed in case the stage is aborted for any reason if (stageIdToNumTasks.isEmpty) { allocationManager.onSchedulerQueueEmpty() + if (numRunningTasks != 0) { + logWarning("No stages are running, but numRunningTasks != 0") + numRunningTasks = 0 + } } } } @@ -446,6 +497,7 @@ private[spark] class ExecutorAllocationManager( val executorId = taskStart.taskInfo.executorId allocationManager.synchronized { + numRunningTasks += 1 // This guards against the race condition in which the `SparkListenerTaskStart` // event is posted before the `SparkListenerBlockManagerAdded` event, which is // possible because these events are posted in different threads. (see SPARK-4951) @@ -475,7 +527,8 @@ private[spark] class ExecutorAllocationManager( val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId allocationManager.synchronized { - // If the executor is no longer running scheduled any tasks, mark it as idle + numRunningTasks -= 1 + // If the executor is no longer running any scheduled tasks, mark it as idle if (executorIdToTaskIds.contains(executorId)) { executorIdToTaskIds(executorId) -= taskId if (executorIdToTaskIds(executorId).isEmpty) { @@ -486,8 +539,8 @@ private[spark] class ExecutorAllocationManager( } } - override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { - val executorId = blockManagerAdded.blockManagerId.executorId + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + val executorId = executorAdded.executorId if (executorId != SparkContext.DRIVER_IDENTIFIER) { // This guards against the race condition in which the `SparkListenerTaskStart` // event is posted before the `SparkListenerBlockManagerAdded` event, which is @@ -498,9 +551,8 @@ private[spark] class ExecutorAllocationManager( } } - override def onBlockManagerRemoved( - blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { - allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId) + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + allocationManager.onExecutorRemoved(executorRemoved.executorId) } /** @@ -515,6 +567,11 @@ private[spark] class ExecutorAllocationManager( }.sum } + /** + * The number of tasks currently running across all stages. + */ + def totalRunningTasks(): Int = numRunningTasks + /** * Return true if an executor is not currently running a task, and false otherwise. * @@ -530,28 +587,3 @@ private[spark] class ExecutorAllocationManager( private object ExecutorAllocationManager { val NOT_SET = Long.MaxValue } - -/** - * An abstract clock for measuring elapsed time. - */ -private trait Clock { - def getTimeMillis: Long -} - -/** - * A clock backed by a monotonically increasing time source. - * The time returned by this clock does not correspond to any notion of wall-clock time. - */ -private class RealClock extends Clock { - override def getTimeMillis: Long = System.nanoTime / (1000 * 1000) -} - -/** - * A clock that allows the caller to customize the time. - * This is used mainly for testing. - */ -private class TestClock(startTimeMillis: Long) extends Clock { - private var time: Long = startTimeMillis - override def getTimeMillis: Long = time - def tick(ms: Long): Unit = { time += ms } -} diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 88d35a4bacc6e..3653f724ba192 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.network.sasl.SecretKeyHolder +import org.apache.spark.util.Utils /** * Spark class responsible for security. @@ -203,7 +204,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) // always add the current user and SPARK_USER to the viewAcls private val defaultAclUsers = Set[String](System.getProperty("user.name", ""), - Option(System.getenv("SPARK_USER")).getOrElse("")).filter(!_.isEmpty) + Utils.getCurrentUserName()) setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 13aa9960ac33a..f3de467883f85 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import scala.collection.mutable.LinkedHashSet @@ -285,7 +286,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { // Validate memory fractions val memoryKeys = Seq( "spark.storage.memoryFraction", - "spark.shuffle.memoryFraction", + "spark.shuffle.memoryFraction", "spark.shuffle.safetyFraction", "spark.storage.unrollFraction", "spark.storage.safetyFraction") @@ -342,6 +343,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } } + + // Warn against the use of deprecated configs + deprecatedConfigs.values.foreach { dc => + if (contains(dc.oldName)) { + dc.warn() + } + } } /** @@ -351,9 +359,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def toDebugString: String = { getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") } + } -private[spark] object SparkConf { +private[spark] object SparkConf extends Logging { + + private val deprecatedConfigs: Map[String, DeprecatedConfig] = { + val configs = Seq( + DeprecatedConfig("spark.files.userClassPathFirst", "spark.executor.userClassPathFirst", + "1.3"), + DeprecatedConfig("spark.yarn.user.classpath.first", null, "1.3", + "Use spark.{driver,executor}.userClassPathFirst instead.")) + configs.map { x => (x.oldName, x) }.toMap + } + /** * Return whether the given config is an akka config (e.g. akka.actor.provider). * Note that this does not include spark-specific akka configs (e.g. spark.akka.timeout). @@ -380,4 +399,63 @@ private[spark] object SparkConf { def isSparkPortConf(name: String): Boolean = { (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.") } + + /** + * Translate the configuration key if it is deprecated and has a replacement, otherwise just + * returns the provided key. + * + * @param userKey Configuration key from the user / caller. + * @param warn Whether to print a warning if the key is deprecated. Warnings will be printed + * only once for each key. + */ + def translateConfKey(userKey: String, warn: Boolean = false): String = { + deprecatedConfigs.get(userKey) + .map { deprecatedKey => + if (warn) { + deprecatedKey.warn() + } + deprecatedKey.newName.getOrElse(userKey) + }.getOrElse(userKey) + } + + /** + * Holds information about keys that have been deprecated or renamed. + * + * @param oldName Old configuration key. + * @param newName New configuration key, or `null` if key has no replacement, in which case the + * deprecated key will be used (but the warning message will still be printed). + * @param version Version of Spark where key was deprecated. + * @param deprecationMessage Message to include in the deprecation warning; mandatory when + * `newName` is not provided. + */ + private case class DeprecatedConfig( + oldName: String, + _newName: String, + version: String, + deprecationMessage: String = null) { + + private val warned = new AtomicBoolean(false) + val newName = Option(_newName) + + if (newName == null && (deprecationMessage == null || deprecationMessage.isEmpty())) { + throw new IllegalArgumentException("Need new config name or deprecation message.") + } + + def warn(): Unit = { + if (warned.compareAndSet(false, true)) { + if (newName != null) { + val message = Option(deprecationMessage).getOrElse( + s"Please use the alternative '$newName' instead.") + logWarning( + s"The configuration option '$oldName' has been replaced as of Spark $version and " + + s"may be removed in the future. $message") + } else { + logWarning( + s"The configuration option '$oldName' has been deprecated as of Spark $version and " + + s"may be removed in the future. $deprecationMessage") + } + } + } + + } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6a16a31654630..495227b201e97 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -20,33 +20,43 @@ package org.apache.spark import scala.language.implicitConversions import java.io._ +import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger import java.util.UUID.randomUUID + import scala.collection.{Map, Set} import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} + +import akka.actor.Props + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} -import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat} +import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, + FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, + TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} + import org.apache.mesos.MesosNativeLibrary -import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.executor.TriggerThreadDump -import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} +import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, + FixedLengthBinaryInputFormat} +import org.apache.spark.io.CompressionCodec import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, + SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ @@ -182,7 +192,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") - + private[spark] val conf = config.clone() conf.validateSettings() @@ -217,9 +227,19 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val appName = conf.get("spark.app.name") private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) - private[spark] val eventLogDir: Option[String] = { + private[spark] val eventLogDir: Option[URI] = { if (isEventLogEnabled) { - Some(conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR).stripSuffix("/")) + val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) + .stripSuffix("/") + Some(Utils.resolveURI(unresolvedDir)) + } else { + None + } + } + private[spark] val eventLogCodec: Option[String] = { + val compress = conf.getBoolean("spark.eventLog.compress", false) + if (compress && isEventLogEnabled) { + Some(CompressionCodec.getCodecName(conf)).map(CompressionCodec.getShortName) } else { None } @@ -240,7 +260,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) // Create the Spark execution environment (cache, map output tracker, etc) - private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus) + + // This function allows components created by SparkEnv to be mocked in unit tests: + private[spark] def createSparkEnv( + conf: SparkConf, + isLocal: Boolean, + listenerBus: LiveListenerBus): SparkEnv = { + SparkEnv.createDriverEnv(conf, isLocal, listenerBus) + } + + private[spark] val env = createSparkEnv(conf, isLocal, listenerBus) SparkEnv.set(env) // Used to store a URL for each static file/jar together with the file's local timestamp @@ -279,7 +308,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // the bound port to the cluster manager properly ui.foreach(_.bind()) - /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ + /** + * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. + * + * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * plan to set some global configurations for all Hadoop RDDs. + */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) // Add each JAR given through the constructor @@ -321,11 +355,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli executorEnvs ++= conf.getExecutorEnv // Set SPARK_USER for user who is running SparkContext. - val sparkUser = Option { - Option(System.getenv("SPARK_USER")).getOrElse(System.getProperty("user.name")) - }.getOrElse { - SparkContext.SPARK_UNKNOWN_USER - } + val sparkUser = Utils.getCurrentUserName() executorEnvs("SPARK_USER") = sparkUser // Create and start the scheduler @@ -387,9 +417,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } executorAllocationManager.foreach(_.start()) - // At this point, all relevant SparkListeners have been registered, so begin releasing events - listenerBus.start() - private[spark] val cleaner: Option[ContextCleaner] = { if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { Some(new ContextCleaner(this)) @@ -399,6 +426,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } cleaner.foreach(_.start()) + setupAndStartListenerBus() postEnvironmentUpdate() postApplicationStart() @@ -407,6 +435,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Thread Local variable that can be used by users to pass information down the stack private val localProperties = new InheritableThreadLocal[Properties] { override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() } /** @@ -448,9 +477,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Spark fair scheduler pool. */ def setLocalProperty(key: String, value: String) { - if (localProperties.get() == null) { - localProperties.set(new Properties()) - } if (value == null) { localProperties.get.remove(key) } else { @@ -657,6 +683,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * Load data from a flat binary file, assuming the length of each record is constant. * + * '''Note:''' We ensure that the byte array for each record in the resulting RDD + * has the provided record length. + * * @param path Directory to the input data files * @param recordLength The length at which to split the records * @return An RDD of data with values, represented as byte arrays @@ -671,7 +700,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli classOf[LongWritable], classOf[BytesWritable], conf=conf) - val data = br.map{ case (k, v) => v.getBytes} + val data = br.map { case (k, v) => + val bytes = v.getBytes + assert(bytes.length == recordLength, "Byte array does not have correct length") + bytes + } data } @@ -680,7 +713,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), * using the older MapReduce API (`org.apache.hadoop.mapred`). * - * @param conf JobConf for setting up the dataset + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. * @param inputFormatClass Class of the InputFormat * @param keyClass Class of the keys * @param valueClass Class of the values @@ -804,7 +840,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { assertNotStopped() - // The call to new NewHadoopJob automatically adds security credentials to conf, + // The call to new NewHadoopJob automatically adds security credentials to conf, // so we don't need to explicitly add them ourselves val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) @@ -816,6 +852,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * + * @param conf Configuration for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param fClass Class of the InputFormat + * @param kClass Class of the keys + * @param vClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. @@ -926,11 +970,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** Build the union of a list of RDDs. */ - def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) + def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = { + val partitioners = rdds.flatMap(_.partitioner).toSet + if (partitioners.size == 1) { + new PartitionerAwareUnionRDD(this, rdds) + } else { + new UnionRDD(this, rdds) + } + } /** Build the union of a list of RDDs passed as variable-length arguments. */ def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] = - new UnionRDD(this, Seq(first) ++ rest) + union(Seq(first) ++ rest) /** Get an RDD that has no partitions or elements. */ def emptyRDD[T: ClassTag] = new EmptyRDD[T](this) @@ -1010,12 +1061,48 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. */ - def addFile(path: String) { + def addFile(path: String): Unit = { + addFile(path, false) + } + + /** + * Add a file to be downloaded with this Spark job on every node. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(fileName)` to find its download location. + * + * A directory can be given if the recursive option is set to true. Currently directories are only + * supported for Hadoop-supported filesystems. + */ + def addFile(path: String, recursive: Boolean): Unit = { val uri = new URI(path) - val key = uri.getScheme match { - case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) - case "local" => "file:" + uri.getPath - case _ => path + val schemeCorrectedPath = uri.getScheme match { + case null | "local" => new File(path).getCanonicalFile.toURI.toString + case _ => path + } + + val hadoopPath = new Path(schemeCorrectedPath) + val scheme = new URI(schemeCorrectedPath).getScheme + if (!Array("http", "https", "ftp").contains(scheme)) { + val fs = hadoopPath.getFileSystem(hadoopConfiguration) + if (!fs.exists(hadoopPath)) { + throw new FileNotFoundException(s"Added file $hadoopPath does not exist.") + } + val isDir = fs.isDirectory(hadoopPath) + if (!isLocal && scheme == "file" && isDir) { + throw new SparkException(s"addFile does not support local directories when not running " + + "local mode.") + } + if (!recursive && isDir) { + throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " + + "turned on.") + } + } + + val key = if (!isLocal && scheme == "file") { + env.httpFileServer.addFile(new File(uri.getPath)) + } else { + schemeCorrectedPath } val timestamp = System.currentTimeMillis addedFiles(key) = timestamp @@ -1037,10 +1124,27 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli listenerBus.addListener(listener) } + /** + * Express a preference to the cluster manager for a given total number of executors. + * This can result in canceling pending requests or filing additional requests. + * This is currently only supported in YARN mode. Return whether the request is received. + */ + private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { + assert(master.contains("yarn") || dynamicAllocationTesting, + "Requesting executors is currently only supported in YARN mode") + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.requestTotalExecutors(numExecutors) + case _ => + logWarning("Requesting executors is only supported in coarse-grained mode") + false + } + } + /** * :: DeveloperApi :: * Request an additional number of executors from the cluster manager. - * This is currently only supported in Yarn mode. Return whether the request is received. + * This is currently only supported in YARN mode. Return whether the request is received. */ @DeveloperApi override def requestExecutors(numAdditionalExecutors: Int): Boolean = { @@ -1058,7 +1162,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. - * This is currently only supported in Yarn mode. Return whether the request is received. + * This is currently only supported in YARN mode. Return whether the request is received. */ @DeveloperApi override def killExecutors(executorIds: Seq[String]): Boolean = { @@ -1224,7 +1328,19 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli null } } else { - env.httpFileServer.addJar(new File(uri.getPath)) + try { + env.httpFileServer.addJar(new File(uri.getPath)) + } catch { + case exc: FileNotFoundException => + logError(s"Jar not found at $path") + null + case e: Exception => + // For now just log an error but allow to go through so spark examples work. + // The spark examples don't really need the jar distributed since its also + // the app jar. + logError("Error adding jar (" + e + "), was the --addJars option used?") + null + } } // A JAR file which exists locally on every worker node case "local" => @@ -1253,22 +1369,24 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Shut down the SparkContext. */ def stop() { SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - postApplicationEnd() - ui.foreach(_.stop()) if (!stopped) { stopped = true + postApplicationEnd() + ui.foreach(_.stop()) env.metricsSystem.report() metadataCleaner.cancel() - env.actorSystem.stop(heartbeatReceiver) cleaner.foreach(_.stop()) + executorAllocationManager.foreach(_.stop()) dagScheduler.stop() dagScheduler = null + listenerBus.stop() + eventLogger.foreach(_.stop()) + env.actorSystem.stop(heartbeatReceiver) + progressBar.foreach(_.stop()) taskScheduler = null // TODO: Cache.stop()? env.stop() SparkEnv.set(null) - listenerBus.stop() - eventLogger.foreach(_.stop()) logInfo("Successfully stopped SparkContext") SparkContext.clearActiveContext() } else { @@ -1342,6 +1460,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val callSite = getCallSite val cleanedFunc = clean(func) logInfo("Starting job: " + callSite.shortForm) + if (conf.getBoolean("spark.logLineage", false)) { + logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString) + } dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) progressBar.foreach(_.finishAll()) @@ -1528,8 +1649,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @deprecated("use defaultMinPartitions", "1.0.0") def defaultMinSplits: Int = math.min(defaultParallelism, 2) - /** - * Default min number of partitions for Hadoop RDDs when not given by user + /** + * Default min number of partitions for Hadoop RDDs when not given by user * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2. * The reasons for this are discussed in https://github.com/mesos/spark/pull/718 */ @@ -1544,6 +1665,58 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Register a new RDD, returning its RDD ID */ private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + /** + * Registers listeners specified in spark.extraListeners, then starts the listener bus. + * This should be called after all internal listeners have been registered with the listener bus + * (e.g. after the web UI and event logging listeners have been registered). + */ + private def setupAndStartListenerBus(): Unit = { + // Use reflection to instantiate listeners specified via `spark.extraListeners` + try { + val listenerClassNames: Seq[String] = + conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "") + for (className <- listenerClassNames) { + // Use reflection to find the right constructor + val constructors = { + val listenerClass = Class.forName(className) + listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]] + } + val constructorTakingSparkConf = constructors.find { c => + c.getParameterTypes.sameElements(Array(classOf[SparkConf])) + } + lazy val zeroArgumentConstructor = constructors.find { c => + c.getParameterTypes.isEmpty + } + val listener: SparkListener = { + if (constructorTakingSparkConf.isDefined) { + constructorTakingSparkConf.get.newInstance(conf) + } else if (zeroArgumentConstructor.isDefined) { + zeroArgumentConstructor.get.newInstance() + } else { + throw new SparkException( + s"$className did not have a zero-argument constructor or a" + + " single-argument constructor that accepts SparkConf. Note: if the class is" + + " defined inside of another Scala class, then its constructors may accept an" + + " implicit parameter that references the enclosing class; in this case, you must" + + " define the listener as a top-level class in order to prevent this extra" + + " parameter from breaking Spark's ability to find a valid constructor.") + } + } + listenerBus.addListener(listener) + logInfo(s"Registered listener $className") + } + } catch { + case e: Exception => + try { + stop() + } finally { + throw new SparkException(s"Exception when registering SparkListener", e) + } + } + + listenerBus.start() + } + /** Post the application start event */ private def postApplicationStart() { // Note: this code assumes that the task scheduler has been initialized and has contacted @@ -1563,8 +1736,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val schedulingMode = getSchedulingMode.toString val addedJarPaths = addedJars.keys.toSeq val addedFilePaths = addedFiles.keys.toSeq - val environmentDetails = - SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, addedFilePaths) + val environmentDetails = SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, + addedFilePaths) val environmentUpdate = SparkListenerEnvironmentUpdate(environmentDetails) listenerBus.post(environmentUpdate) } @@ -1694,8 +1867,6 @@ object SparkContext extends Logging { private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" - private[spark] val SPARK_UNKNOWN_USER = "" - private[spark] val DRIVER_IDENTIFIER = "" // The following deprecated objects have already been copied to `object AccumulatorParam` to @@ -1749,8 +1920,14 @@ object SparkContext extends Logging { @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]) = + rdd: RDD[(K, V)]) = { + val kf = implicitly[K => Writable] + val vf = implicitly[V => Writable] + // Set the Writable class to null and `SequenceFileRDDFunctions` will use Reflection to get it + implicit val keyWritableFactory = new WritableFactory[K](_ => null, kf) + implicit val valueWritableFactory = new WritableFactory[V](_ => null, vf) RDD.rddToSequenceFileRDDFunctions(rdd) + } @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") @@ -1767,20 +1944,35 @@ object SparkContext extends Logging { def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = RDD.numericRDDToDoubleRDDFunctions(rdd) - // Implicit conversions to common Writable types, for saveAsSequenceFile + // The following deprecated functions have already been moved to `object WritableFactory` to + // make the compiler find them automatically. They are still kept here for backward compatibility. + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def intToIntWritable(i: Int): IntWritable = new IntWritable(i) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def longToLongWritable(l: Long): LongWritable = new LongWritable(l) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def floatToFloatWritable(f: Float): FloatWritable = new FloatWritable(f) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def doubleToDoubleWritable(d: Double): DoubleWritable = new DoubleWritable(d) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def boolToBoolWritable (b: Boolean): BooleanWritable = new BooleanWritable(b) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def bytesToBytesWritable (aob: Array[Byte]): BytesWritable = new BytesWritable(aob) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def stringToText(s: String): Text = new Text(s) private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]) @@ -1959,7 +2151,7 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc) val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) + numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf) val masterUrls = localCluster.start() val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls) scheduler.initialize(backend) @@ -2070,7 +2262,7 @@ object WritableConverter { new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) } - // The following implicit functions were in SparkContext before 1.2 and users had to + // The following implicit functions were in SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward // compatibility and forward to the following functions directly. @@ -2103,3 +2295,46 @@ object WritableConverter { implicit def writableWritableConverter[T <: Writable](): WritableConverter[T] = new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) } + +/** + * A class encapsulating how to convert some type T to Writable. It stores both the Writable class + * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. + * The Writable class will be used in `SequenceFileRDDFunctions`. + */ +private[spark] class WritableFactory[T]( + val writableClass: ClassTag[T] => Class[_ <: Writable], + val convert: T => Writable) extends Serializable + +object WritableFactory { + + private[spark] def simpleWritableFactory[T: ClassTag, W <: Writable : ClassTag](convert: T => W) + : WritableFactory[T] = { + val writableClass = implicitly[ClassTag[W]].runtimeClass.asInstanceOf[Class[W]] + new WritableFactory[T](_ => writableClass, convert) + } + + implicit def intWritableFactory: WritableFactory[Int] = + simpleWritableFactory(new IntWritable(_)) + + implicit def longWritableFactory: WritableFactory[Long] = + simpleWritableFactory(new LongWritable(_)) + + implicit def floatWritableFactory: WritableFactory[Float] = + simpleWritableFactory(new FloatWritable(_)) + + implicit def doubleWritableFactory: WritableFactory[Double] = + simpleWritableFactory(new DoubleWritable(_)) + + implicit def booleanWritableFactory: WritableFactory[Boolean] = + simpleWritableFactory(new BooleanWritable(_)) + + implicit def bytesWritableFactory: WritableFactory[Array[Byte]] = + simpleWritableFactory(new BytesWritable(_)) + + implicit def stringWritableFactory: WritableFactory[String] = + simpleWritableFactory(new Text(_)) + + implicit def writableWritableFactory[T <: Writable: ClassTag]: WritableFactory[T] = + simpleWritableFactory(w => w) + +} diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index f25db7f8de565..d95a176d18928 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,7 +34,8 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} +import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ @@ -67,6 +68,7 @@ class SparkEnv ( val sparkFilesDir: String, val metricsSystem: MetricsSystem, val shuffleMemoryManager: ShuffleMemoryManager, + val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { private[spark] var isStopped = false @@ -86,6 +88,7 @@ class SparkEnv ( blockManager.stop() blockManager.master.stop() metricsSystem.stop() + outputCommitCoordinator.stop() actorSystem.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release @@ -151,7 +154,8 @@ object SparkEnv extends Logging { private[spark] def createDriverEnv( conf: SparkConf, isLocal: Boolean, - listenerBus: LiveListenerBus): SparkEnv = { + listenerBus: LiveListenerBus, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!") assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") val hostname = conf.get("spark.driver.host") @@ -163,7 +167,8 @@ object SparkEnv extends Logging { port, isDriver = true, isLocal = isLocal, - listenerBus = listenerBus + listenerBus = listenerBus, + mockOutputCommitCoordinator = mockOutputCommitCoordinator ) } @@ -202,7 +207,8 @@ object SparkEnv extends Logging { isDriver: Boolean, isLocal: Boolean, listenerBus: LiveListenerBus = null, - numUsableCores: Int = 0): SparkEnv = { + numUsableCores: Int = 0, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { // Listener bus is only used on the driver if (isDriver) { @@ -350,6 +356,13 @@ object SparkEnv extends Logging { "levels using the RDD.persist() method instead.") } + val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { + new OutputCommitCoordinator(conf) + } + val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator", + new OutputCommitCoordinatorActor(outputCommitCoordinator)) + outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor) + new SparkEnv( executorId, actorSystem, @@ -366,6 +379,7 @@ object SparkEnv extends Logging { sparkFilesDir, metricsSystem, shuffleMemoryManager, + outputCommitCoordinator, conf) } diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 40237596570de..2ec42d3aea169 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -103,26 +103,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - if (cmtr.needsTaskCommit(taCtxt)) { - try { - cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") - } catch { - case e: IOException => { - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } - } else { - logInfo ("No need to commit output of task: " + taID.value) - } + SparkHadoopMapRedUtil.commitTask( + getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) } def commitJob() { - // always ? Or if cmtr.needsTaskCommit ? val cmtr = getOutputCommitter() cmtr.commitJob(getJobContext()) } diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index edbdda8a0bcb6..34ee3a48f8e74 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -45,8 +45,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { */ def getJobIdsForGroup(jobGroup: String): Array[Int] = { jobProgressListener.synchronized { - val jobData = jobProgressListener.jobIdToData.valuesIterator - jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray + jobProgressListener.jobGroupToJobIds.getOrElse(jobGroup, Seq.empty).toArray } } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index af5fd8e0ac00c..29a5cd5fdac76 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -146,6 +146,20 @@ case object TaskKilled extends TaskFailedReason { override def toErrorString: String = "TaskKilled (killed intentionally)" } +/** + * :: DeveloperApi :: + * Task requested the driver to commit, but was denied. + */ +@DeveloperApi +case class TaskCommitDenied( + jobID: Int, + partitionID: Int, + attemptID: Int) + extends TaskFailedReason { + override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + + s" for job: $jobID, partition: $partitionID, attempt: $attemptID" +} + /** * :: DeveloperApi :: * The task failed because the executor that it was running on was lost. This may happen because diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala index 0bf1e4a5e2ccd..c415fe99b105e 100644 --- a/core/src/main/scala/org/apache/spark/TaskState.scala +++ b/core/src/main/scala/org/apache/spark/TaskState.scala @@ -27,6 +27,8 @@ private[spark] object TaskState extends Enumeration { type TaskState = Value + def isFailed(state: TaskState) = (LOST == state) || (FAILED == state) + def isFinished(state: TaskState) = FINISHED_STATES.contains(state) def toMesos(state: TaskState): MesosTaskState = state match { @@ -46,5 +48,6 @@ private[spark] object TaskState extends Enumeration { case MesosTaskState.TASK_FAILED => FAILED case MesosTaskState.TASK_KILLED => KILLED case MesosTaskState.TASK_LOST => LOST + case MesosTaskState.TASK_ERROR => LOST } } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 34078142f5385..35b324ba6f573 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -17,12 +17,13 @@ package org.apache.spark -import java.io.{File, FileInputStream, FileOutputStream} +import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} import java.util.jar.{JarEntry, JarOutputStream} import scala.collection.JavaConversions._ +import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -43,13 +44,38 @@ private[spark] object TestUtils { * Note: if this is used during class loader tests, class names should be unique * in order to avoid interference between tests. */ - def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = { + def createJarWithClasses( + classNames: Seq[String], + toStringValue: String = "", + classNamesWithBase: Seq[(String, String)] = Seq(), + classpathUrls: Seq[URL] = Seq()): URL = { val tempDir = Utils.createTempDir() - val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value) + val files1 = for (name <- classNames) yield { + createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) + } + val files2 = for ((childName, baseName) <- classNamesWithBase) yield { + createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls) + } val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) - createJar(files, jarFile) + createJar(files1 ++ files2, jarFile) } + /** + * Create a jar file containing multiple files. The `files` map contains a mapping of + * file names in the jar file to their contents. + */ + def createJarWithFiles(files: Map[String, String], dir: File = null): URL = { + val tempDir = Option(dir).getOrElse(Utils.createTempDir()) + val jarFile = File.createTempFile("testJar", ".jar", tempDir) + val jarStream = new JarOutputStream(new FileOutputStream(jarFile)) + files.foreach { case (k, v) => + val entry = new JarEntry(k) + jarStream.putNextEntry(entry) + ByteStreams.copy(new ByteArrayInputStream(v.getBytes(UTF_8)), jarStream) + } + jarStream.close() + jarFile.toURI.toURL + } /** * Create a jar file that contains this set of files. All files will be located at the root @@ -85,15 +111,26 @@ private[spark] object TestUtils { } /** Creates a compiled class with the given name. Class file will be placed in destDir. */ - def createCompiledClass(className: String, destDir: File, value: String = ""): File = { + def createCompiledClass( + className: String, + destDir: File, + toStringValue: String = "", + baseClass: String = null, + classpathUrls: Seq[URL] = Seq()): File = { val compiler = ToolProvider.getSystemJavaCompiler + val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") val sourceFile = new JavaSourceFromString(className, - "public class " + className + " implements java.io.Serializable {" + - " @Override public String toString() { return \"" + value + "\"; }}") + "public class " + className + extendsText + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + toStringValue + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. - compiler.getTask(null, null, null, null, null, Seq(sourceFile)).call() + val options = if (classpathUrls.nonEmpty) { + Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) + } else { + Seq() + } + compiler.getTask(null, null, null, options, null, Seq(sourceFile)).call() val fileName = className + ".class" val result = new File(fileName) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 8e8f7f6c4fda2..79e4ebf2db578 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -32,7 +32,8 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter import org.apache.spark.util.Utils -class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, JavaDoubleRDD] { +class JavaDoubleRDD(val srdd: RDD[scala.Double]) + extends AbstractJavaRDDLike[JDouble, JavaDoubleRDD] { override val classTag: ClassTag[JDouble] = implicitly[ClassTag[JDouble]] diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 7af3538262fd6..4eadc9a85613e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.util.Utils class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) (implicit val kClassTag: ClassTag[K], implicit val vClassTag: ClassTag[V]) - extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { + extends AbstractJavaRDDLike[(K, V), JavaPairRDD[K, V]] { override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 86fb374bef1e3..645dc3bfb6b06 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) - extends JavaRDDLike[T, JavaRDD[T]] { + extends AbstractJavaRDDLike[T, JavaRDD[T]] { override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 0f91c942ecd50..8da42934a7d96 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -38,6 +38,14 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils +/** + * As a workaround for https://issues.scala-lang.org/browse/SI-8905, implementations + * of JavaRDDLike should extend this dummy abstract class instead of directly inheriting + * from the trait. See SPARK-3266 for additional details. + */ +private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This]] + extends JavaRDDLike[T, This] + /** * Defines operations common to several Java RDD implementations. * Note that this trait is not intended to be implemented by user code. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 97f5c9f257e09..6d6ed693be752 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -373,6 +373,15 @@ class JavaSparkContext(val sc: SparkContext) * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, * etc). * + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param inputFormatClass Class of the InputFormat + * @param keyClass Class of the keys + * @param valueClass Class of the values + * @param minPartitions Minimum number of Hadoop Splits to generate. + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -395,6 +404,14 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, * + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param inputFormatClass Class of the InputFormat + * @param keyClass Class of the keys + * @param valueClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -476,6 +493,14 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * + * @param conf Configuration for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param fClass Class of the InputFormat + * @param kClass Class of the keys + * @param vClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -675,6 +700,9 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. + * + * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration(): Configuration = { sc.hadoopConfiguration diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala new file mode 100644 index 0000000000000..164e95081583f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io.DataOutputStream +import java.net.Socket + +import py4j.GatewayServer + +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port + * back to its caller via a callback port specified by the caller. + * + * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py). + */ +private[spark] object PythonGatewayServer extends Logging { + def main(args: Array[String]): Unit = Utils.tryOrExit { + // Start a GatewayServer on an ephemeral port + val gatewayServer: GatewayServer = new GatewayServer(null, 0) + gatewayServer.start() + val boundPort: Int = gatewayServer.getListeningPort + if (boundPort == -1) { + logError("GatewayServer failed to bind; exiting") + System.exit(1) + } else { + logDebug(s"Started PythonGatewayServer on port $boundPort") + } + + // Communicate the bound port back to the caller via the caller-specified callback port + val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST") + val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt + logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort") + val callbackSocket = new Socket(callbackHost, callbackPort) + val dos = new DataOutputStream(callbackSocket.getOutputStream) + dos.writeInt(boundPort) + dos.close() + callbackSocket.close() + + // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: + while (System.in.read() != -1) { + // Do nothing + } + logDebug("Exiting due to broken pipe from Python driver") + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b89effc16d36d..eaff5e58122e2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,26 +19,27 @@ package org.apache.spark.api.python import java.io._ import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections} - -import org.apache.spark.input.PortableDataStream +import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConversions._ import scala.collection.mutable import scala.language.existentials import com.google.common.base.Charsets.UTF_8 - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf} +import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat} + import org.apache.spark._ -import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import scala.util.control.NonFatal + private[spark] class PythonRDD( @transient parent: RDD[_], command: Array[Byte], @@ -75,7 +76,6 @@ private[spark] class PythonRDD( context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - writerThread.join() if (!reuse_worker || !released) { try { worker.close() @@ -219,14 +219,13 @@ private[spark] class PythonRDD( val oldBids = PythonRDD.getWorkerBroadcasts(worker) val newBids = broadcastVars.map(_.id).toSet // number of different broadcasts - val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size + val toRemove = oldBids.diff(newBids) + val cnt = toRemove.size + newBids.diff(oldBids).size dataOut.writeInt(cnt) - for (bid <- oldBids) { - if (!newBids.contains(bid)) { - // remove the broadcast from worker - dataOut.writeLong(- bid - 1) // bid >= 0 - oldBids.remove(bid) - } + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) } for (broadcast <- broadcastVars) { if (!oldBids.contains(broadcast.id)) { @@ -248,13 +247,17 @@ private[spark] class PythonRDD( } catch { case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) - worker.shutdownOutput() + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } case e: Exception => // We must avoid throwing exceptions here, because the thread uncaught exception handler // will kill the whole executor (see org.apache.spark.executor.Executor). _exception = e - worker.shutdownOutput() + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } } finally { // Release memory used by this thread for shuffles env.shuffleMemoryManager.releaseMemoryForThisThread() @@ -303,6 +306,7 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) { override def getPartitions = prev.partitions + override val partitioner = prev.partitioner override def compute(split: Partition, context: TaskContext) = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (Utils.deserializeLongValue(a), b) @@ -329,24 +333,45 @@ private[spark] object PythonRDD extends Logging { } } + /** + * Return an RDD of values from an RDD of (Long, Array[Byte]), with preservePartitions=true + * + * This is useful for PySpark to have the partitioner after partitionBy() + */ + def valueOfPair(pair: JavaPairRDD[Long, Array[Byte]]): JavaRDD[Array[Byte]] = { + pair.rdd.mapPartitions(it => it.map(_._2), true) + } + /** * Adapter for calling SparkContext#runJob from Python. * - * This method will return an iterator of an array that contains all elements in the RDD + * This method will serve an iterator of an array that contains all elements in the RDD * (effectively a collect()), but allows you to run on a certain subset of partitions, * or to enable local execution. + * + * @return the port number of a local socket which serves the data collected from this job. */ def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], partitions: JArrayList[Int], - allowLocal: Boolean): Iterator[Array[Byte]] = { + allowLocal: Boolean): Int = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) - flattenedPartition.iterator + serveIterator(flattenedPartition.iterator, + s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}") + } + + /** + * A helper function to collect an RDD as an iterator, then serve it via socket. + * + * @return the port number of a local socket which serves the data collected from this job. + */ + def collectAndServe[T](rdd: RDD[T]): Int = { + serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): @@ -566,15 +591,43 @@ private[spark] object PythonRDD extends Logging { dataOut.write(bytes) } - def writeToFile[T](items: java.util.Iterator[T], filename: String) { - import scala.collection.JavaConverters._ - writeToFile(items.asScala, filename) - } + /** + * Create a socket server and a background thread to serve the data in `items`, + * + * The socket server can only accept one connection, or close if no connection + * in 3 seconds. + * + * Once a connection comes in, it tries to serialize all the data in `items` + * and send them into this connection. + * + * The thread will terminate after all the data are sent or any exceptions happen. + */ + private def serveIterator[T](items: Iterator[T], threadName: String): Int = { + val serverSocket = new ServerSocket(0, 1) + // Close the socket if no connection in 3 seconds + serverSocket.setSoTimeout(3000) + + new Thread(threadName) { + setDaemon(true) + override def run() { + try { + val sock = serverSocket.accept() + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + try { + writeIteratorToStream(items, out) + } finally { + out.close() + } + } catch { + case NonFatal(e) => + logError(s"Error while sending iterator", e) + } finally { + serverSocket.close() + } + } + }.start() - def writeToFile[T](items: Iterator[T], filename: String) { - val file = new DataOutputStream(new FileOutputStream(filename)) - writeIteratorToStream(items, file) - file.close() + serverSocket.getLocalPort } private def getMergedConf(confAsMap: java.util.HashMap[String, String], diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index b7cfc8bd9c542..acbaba6791850 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,8 +17,10 @@ package org.apache.spark.api.python -import java.io.{File, InputStream, IOException, OutputStream} +import java.io.{File} +import java.util.{List => JList} +import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext @@ -44,4 +46,11 @@ private[spark] object PythonUtils { def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = { sc.parallelize(List("a", null, "b")) } + + /** + * Convert list of T into seq of T (for calling API with varargs) + */ + def toSeq[T](cols: JList[T]): Seq[T] = { + cols.toList.toSeq + } } diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index c0cbd28a845be..cf289fb3ae39f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -107,7 +107,6 @@ private[python] class WritableToDoubleArrayConverter extends Converter[Any, Arra * given directory (probably a temp directory) */ object WriteInputFormatTestDataGenerator { - import SparkContext._ def main(args: Array[String]) { val path = args(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index ae55b4ff40b74..b7ae9c1fc0a23 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -17,13 +17,17 @@ package org.apache.spark.deploy +import java.net.URI + private[spark] class ApplicationDescription( val name: String, val maxCores: Option[Int], val memoryPerSlave: Int, val command: Command, var appUiUrl: String, - val eventLogDir: Option[String] = None) + val eventLogDir: Option[URI] = None, + // short name of compression codec used when writing event logs, if any (e.g. lzf) + val eventLogCodec: Option[String] = None) extends Serializable { val user = System.getProperty("user.name", "") @@ -34,8 +38,10 @@ private[spark] class ApplicationDescription( memoryPerSlave: Int = memoryPerSlave, command: Command = command, appUiUrl: String = appUiUrl, - eventLogDir: Option[String] = eventLogDir): ApplicationDescription = - new ApplicationDescription(name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir) + eventLogDir: Option[URI] = eventLogDir, + eventLogCodec: Option[String] = eventLogCodec): ApplicationDescription = + new ApplicationDescription( + name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir, eventLogCodec) override def toString: String = "ApplicationDescription(" + name + ")" } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 38b3da0b13756..237d26fc6bd0e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -68,8 +68,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) .map(Utils.splitCommandString).getOrElse(Seq.empty) val sparkJavaOpts = Utils.sparkJavaOpts(conf) val javaOpts = sparkJavaOpts ++ extraJavaOpts - val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++ - driverArgs.driverOptions, sys.env, classPathEntries, libraryPathEntries, javaOpts) + val command = new Command(mainClass, + Seq("{{WORKER_URL}}", "{{USER_JAR}}", driverArgs.mainClass) ++ driverArgs.driverOptions, + sys.env, classPathEntries, libraryPathEntries, javaOpts) val driverDescription = new DriverDescription( driverArgs.jarUrl, diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index e5873ce724b9f..415bd50591692 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -29,8 +29,7 @@ import org.apache.spark.util.{IntParam, MemoryParam} * Command-line parser for the driver client. */ private[spark] class ClientArguments(args: Array[String]) { - val defaultCores = 1 - val defaultMemory = 512 + import ClientArguments._ var cmd: String = "" // 'launch' or 'kill' var logLevel = Level.WARN @@ -39,9 +38,9 @@ private[spark] class ClientArguments(args: Array[String]) { var master: String = "" var jarUrl: String = "" var mainClass: String = "" - var supervise: Boolean = false - var memory: Int = defaultMemory - var cores: Int = defaultCores + var supervise: Boolean = DEFAULT_SUPERVISE + var memory: Int = DEFAULT_MEMORY + var cores: Int = DEFAULT_CORES private var _driverOptions = ListBuffer[String]() def driverOptions = _driverOptions.toSeq @@ -50,7 +49,7 @@ private[spark] class ClientArguments(args: Array[String]) { parse(args.toList) - def parse(args: List[String]): Unit = args match { + private def parse(args: List[String]): Unit = args match { case ("--cores" | "-c") :: IntParam(value) :: tail => cores = value parse(tail) @@ -106,9 +105,10 @@ private[spark] class ClientArguments(args: Array[String]) { |Usage: DriverClient kill | |Options: - | -c CORES, --cores CORES Number of cores to request (default: $defaultCores) - | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory) + | -c CORES, --cores CORES Number of cores to request (default: $DEFAULT_CORES) + | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $DEFAULT_MEMORY) | -s, --supervise Whether to restart the driver on failure + | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin System.err.println(usage) @@ -117,6 +117,10 @@ private[spark] class ClientArguments(args: Array[String]) { } object ClientArguments { + private[spark] val DEFAULT_CORES = 1 + private[spark] val DEFAULT_MEMORY = 512 // MB + private[spark] val DEFAULT_SUPERVISE = false + def isValidJarUrl(s: String): Boolean = { try { val uri = new URI(s) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 243d8edb72ed3..7f600d89604a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -148,15 +148,22 @@ private[deploy] object DeployMessages { // Master to MasterWebUI - case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo], - activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo], - activeDrivers: Array[DriverInfo], completedDrivers: Array[DriverInfo], - status: MasterState) { + case class MasterStateResponse( + host: String, + port: Int, + restPort: Option[Int], + workers: Array[WorkerInfo], + activeApps: Array[ApplicationInfo], + completedApps: Array[ApplicationInfo], + activeDrivers: Array[DriverInfo], + completedDrivers: Array[DriverInfo], + status: MasterState) { Utils.checkHost(host, "Required hostname") assert (port > 0) def uri = "spark://" + host + ":" + port + def restUri: Option[String] = restPort.map { p => "spark://" + host + ":" + p } } // WorkerWebUI to Worker diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 9a7a113c95715..3ab425aab84c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -33,7 +33,11 @@ import org.apache.spark.util.Utils * fault recovery without spinning up a lot of processes. */ private[spark] -class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) +class LocalSparkCluster( + numWorkers: Int, + coresPerWorker: Int, + memoryPerWorker: Int, + conf: SparkConf) extends Logging { private val localHostname = Utils.localHostName() @@ -43,9 +47,11 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") + // Disable REST server on Master in this mode unless otherwise specified + val _conf = conf.clone().setIfMissing("spark.master.rest.enabled", "false") + /* Start the Master */ - val conf = new SparkConf(false) - val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) + val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf) masterActorSystems += masterSystem val masterUrl = "spark://" + localHostname + ":" + masterPort val masters = Array(masterUrl) @@ -53,7 +59,7 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I /* Start the Workers */ for (workerNum <- 1 to numWorkers) { val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, - memoryPerWorker, masters, null, Some(workerNum)) + memoryPerWorker, masters, null, Some(workerNum), _conf) workerActorSystems += workerSystem } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index d68854214ef06..e0a32fb65cd51 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,7 +21,7 @@ import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} @@ -52,18 +52,13 @@ class SparkHadoopUtil extends Logging { * do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems */ def runAsSparkUser(func: () => Unit) { - val user = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER) - if (user != SparkContext.SPARK_UNKNOWN_USER) { - logDebug("running as user: " + user) - val ugi = UserGroupInformation.createRemoteUser(user) - transferCredentials(UserGroupInformation.getCurrentUser(), ugi) - ugi.doAs(new PrivilegedExceptionAction[Unit] { - def run: Unit = func() - }) - } else { - logDebug("running as SPARK_UNKNOWN_USER") - func() - } + val user = Utils.getCurrentUserName() + logDebug("running as user: " + user) + val ugi = UserGroupInformation.createRemoteUser(user) + transferCredentials(UserGroupInformation.getCurrentUser(), ugi) + ugi.doAs(new PrivilegedExceptionAction[Unit] { + def run: Unit = func() + }) } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { @@ -191,6 +186,21 @@ class SparkHadoopUtil extends Logging { val method = context.getClass.getMethod("getConfiguration") method.invoke(context).asInstanceOf[Configuration] } + + /** + * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the + * given path points to a file, return a single-element collection containing [[FileStatus]] of + * that file. + */ + def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { + def recurse(path: Path) = { + val (directories, leaves) = fs.listStatus(path).partition(_.isDir) + leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath)) + } + + val baseStatus = fs.getFileStatus(basePath) + if (baseStatus.isDir) recurse(basePath) else Array(baseStatus) + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 02021be9f93d4..4a74641f4e1fa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -18,15 +18,37 @@ package org.apache.spark.deploy import java.io.{File, PrintStream} -import java.lang.reflect.{Modifier, InvocationTargetException} +import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL +import java.security.PrivilegedExceptionAction import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import org.apache.hadoop.fs.Path +import org.apache.hadoop.security.UserGroupInformation +import org.apache.ivy.Ivy +import org.apache.ivy.core.LogOptions +import org.apache.ivy.core.module.descriptor._ +import org.apache.ivy.core.module.id.{ArtifactId, ModuleId, ModuleRevisionId} +import org.apache.ivy.core.report.ResolveReport +import org.apache.ivy.core.resolve.ResolveOptions +import org.apache.ivy.core.retrieve.RetrieveOptions +import org.apache.ivy.core.settings.IvySettings +import org.apache.ivy.plugins.matcher.GlobPatternMatcher +import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} + +import org.apache.spark.SPARK_VERSION +import org.apache.spark.deploy.rest._ +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} -import org.apache.spark.executor.ExecutorURLClassLoader -import org.apache.spark.util.Utils +/** + * Whether to submit, kill, or request the status of an application. + * The latter two operations are currently supported only for standalone cluster mode. + */ +private[spark] object SparkSubmitAction extends Enumeration { + type SparkSubmitAction = Value + val SUBMIT, KILL, REQUEST_STATUS = Value +} /** * Main gateway of launching a Spark application. @@ -59,35 +81,127 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // Exposed for testing - private[spark] var exitFn: () => Unit = () => System.exit(-1) + private[spark] var exitFn: () => Unit = () => System.exit(1) private[spark] var printStream: PrintStream = System.err - private[spark] def printWarning(str: String) = printStream.println("Warning: " + str) - private[spark] def printErrorAndExit(str: String) = { + private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + private[spark] def printErrorAndExit(str: String): Unit = { printStream.println("Error: " + str) printStream.println("Run with --help for usage help or --verbose for debug output") exitFn() } + private[spark] def printVersionAndExit(): Unit = { + printStream.println("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + printStream.println("Type --help for more information.") + exitFn() + } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { printStream.println(appArgs) } - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) - launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose) + appArgs.action match { + case SparkSubmitAction.SUBMIT => submit(appArgs) + case SparkSubmitAction.KILL => kill(appArgs) + case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) + } + } + + /** Kill an existing submission using the REST protocol. Standalone cluster mode only. */ + private def kill(args: SparkSubmitArguments): Unit = { + new StandaloneRestClient() + .killSubmission(args.master, args.submissionToKill) } /** - * @return a tuple containing - * (1) the arguments for the child process, - * (2) a list of classpath entries for the child, - * (3) a list of system properties and env vars, and - * (4) the main class for the child + * Request the status of an existing submission using the REST protocol. + * Standalone cluster mode only. */ - private[spark] def createLaunchEnv(args: SparkSubmitArguments) - : (ArrayBuffer[String], ArrayBuffer[String], Map[String, String], String) = { + private def requestStatus(args: SparkSubmitArguments): Unit = { + new StandaloneRestClient() + .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor) + } - // Values to return + /** + * Submit the application using the provided parameters. + * + * This runs in two steps. First, we prepare the launch environment by setting up + * the appropriate classpath, system properties, and application arguments for + * running the child main class based on the cluster manager and the deploy mode. + * Second, we use this launch environment to invoke the main method of the child + * main class. + */ + private[spark] def submit(args: SparkSubmitArguments): Unit = { + val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) + + def doRunMain(): Unit = { + if (args.proxyUser != null) { + val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser, + UserGroupInformation.getCurrentUser()) + try { + proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + } + }) + } catch { + case e: Exception => + // Hadoop's AuthorizationException suppresses the exception's stack trace, which + // makes the message printed to the output by the JVM not very helpful. Instead, + // detect exceptions with empty stack traces here, and treat them differently. + if (e.getStackTrace().length == 0) { + printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + exitFn() + } else { + throw e + } + } + } else { + runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + } + } + + // In standalone cluster mode, there are two submission gateways: + // (1) The traditional Akka gateway using o.a.s.deploy.Client as a wrapper + // (2) The new REST-based gateway introduced in Spark 1.3 + // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over + // to use the legacy gateway if the master endpoint turns out to be not a REST server. + if (args.isStandaloneCluster && args.useRest) { + try { + printStream.println("Running Spark using the REST application submission protocol.") + doRunMain() + } catch { + // Fail over to use the legacy submission gateway + case e: SubmitRestConnectionException => + printWarning(s"Master endpoint ${args.master} was not a REST server. " + + "Falling back to legacy submission gateway instead.") + args.useRest = false + submit(args) + } + // In all other modes, just run the main class as prepared + } else { + doRunMain() + } + } + + /** + * Prepare the environment for submitting an application. + * This returns a 4-tuple: + * (1) the arguments for the child process, + * (2) a list of classpath entries for the child, + * (3) a map of system properties, and + * (4) the main class for the child + * Exposed for testing. + */ + private[spark] def prepareSubmitEnvironment(args: SparkSubmitArguments) + : (Seq[String], Seq[String], Map[String, String], String) = { + // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() val sysProps = new HashMap[String, String]() @@ -138,6 +252,26 @@ object SparkSubmit { val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files + // too for packages that include Python code + val resolvedMavenCoordinates = + SparkSubmitUtils.resolveMavenCoordinates( + args.packages, Option(args.repositories), Option(args.ivyRepoPath)) + if (!resolvedMavenCoordinates.trim.isEmpty) { + if (args.jars == null || args.jars.trim.isEmpty) { + args.jars = resolvedMavenCoordinates + } else { + args.jars += s",$resolvedMavenCoordinates" + } + if (args.isPython) { + if (args.pyFiles == null || args.pyFiles.trim.isEmpty) { + args.pyFiles = resolvedMavenCoordinates + } else { + args.pyFiles += s",$resolvedMavenCoordinates" + } + } + } + // Require all python files to be local, so we can add them to the PYTHONPATH // In YARN cluster mode, python files are distributed as regular files, which can be non-local if (args.isPython && !isYarnCluster) { @@ -169,8 +303,7 @@ object SparkSubmit { // If we're running a python app, set the main class to our specific python runner if (args.isPython && deployMode == CLIENT) { if (args.primaryResource == PYSPARK_SHELL) { - args.mainClass = "py4j.GatewayServer" - args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0") + args.mainClass = "org.apache.spark.api.python.PythonGatewayServer" } else { // If a python file is provided, add it to the child arguments and list of files to deploy. // Usage: PythonAppRunner
[app arguments] @@ -202,6 +335,7 @@ object SparkSubmit { OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), + OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.driver.memory"), OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, @@ -212,9 +346,13 @@ object SparkSubmit { sysProp = "spark.driver.extraLibraryPath"), // Standalone cluster only + // Do not set CL arguments here because there are multiple possibilities for the main class OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), - OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), - OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), + OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, sysProp = "spark.driver.memory"), + OptionAssigner(args.driverCores, STANDALONE, CLUSTER, sysProp = "spark.driver.cores"), + OptionAssigner(args.supervise.toString, STANDALONE, CLUSTER, + sysProp = "spark.driver.supervise"), // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), @@ -255,7 +393,6 @@ object SparkSubmit { if (args.childArgs != null) { childArgs ++= args.childArgs } } - // Map all arguments to command-line options or system properties for our chosen mode for (opt <- options) { if (opt.value != null && @@ -277,14 +414,21 @@ object SparkSubmit { sysProps.put("spark.jars", jars.mkString(",")) } - // In standalone-cluster mode, use Client as a wrapper around the user class - if (clusterManager == STANDALONE && deployMode == CLUSTER) { - childMainClass = "org.apache.spark.deploy.Client" - if (args.supervise) { - childArgs += "--supervise" + // In standalone cluster mode, use the REST client to submit the application (Spark 1.3+). + // All Spark parameters are expected to be passed to the client through system properties. + if (args.isStandaloneCluster) { + if (args.useRest) { + childMainClass = "org.apache.spark.deploy.rest.StandaloneRestClient" + childArgs += (args.primaryResource, args.mainClass) + } else { + // In legacy standalone cluster mode, use Client as a wrapper around the user class + childMainClass = "org.apache.spark.deploy.Client" + if (args.supervise) { childArgs += "--supervise" } + Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) } + Option(args.driverCores).foreach { c => childArgs += ("--cores", c) } + childArgs += "launch" + childArgs += (args.master, args.primaryResource, args.mainClass) } - childArgs += "launch" - childArgs += (args.master, args.primaryResource, args.mainClass) if (args.childArgs != null) { childArgs ++= args.childArgs } @@ -321,7 +465,7 @@ object SparkSubmit { // Ignore invalid spark.driver.host in cluster modes. if (deployMode == CLUSTER) { - sysProps -= ("spark.driver.host") + sysProps -= "spark.driver.host" } // Resolve paths in certain spark properties @@ -350,12 +494,18 @@ object SparkSubmit { (childArgs, childClasspath, sysProps, childMainClass) } - private def launch( - childArgs: ArrayBuffer[String], - childClasspath: ArrayBuffer[String], + /** + * Run the main method of the child class using the provided launch environment. + * + * Note that this main class will not be the one provided by the user if we're + * running cluster deploy mode or python applications. + */ + private def runMain( + childArgs: Seq[String], + childClasspath: Seq[String], sysProps: Map[String, String], childMainClass: String, - verbose: Boolean = false) { + verbose: Boolean): Unit = { if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") @@ -364,8 +514,14 @@ object SparkSubmit { printStream.println("\n") } - val loader = new ExecutorURLClassLoader(new Array[URL](0), - Thread.currentThread.getContextClassLoader) + val loader = + if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + new ChildFirstURLClassLoader(new Array[URL](0), + Thread.currentThread.getContextClassLoader) + } else { + new MutableURLClassLoader(new Array[URL](0), + Thread.currentThread.getContextClassLoader) + } Thread.currentThread.setContextClassLoader(loader) for (jar <- childClasspath) { @@ -384,8 +540,8 @@ object SparkSubmit { case e: ClassNotFoundException => e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { - println(s"Failed to load main class $childMainClass.") - println("You need to build Spark with -Phive and -Phive-thriftserver.") + printStream.println(s"Failed to load main class $childMainClass.") + printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } @@ -399,17 +555,25 @@ object SparkSubmit { if (!Modifier.isStatic(mainMethod.getModifiers)) { throw new IllegalStateException("The main method in the given main class must be static") } + + def findCause(t: Throwable): Throwable = t match { + case e: UndeclaredThrowableException => + if (e.getCause() != null) findCause(e.getCause()) else e + case e: InvocationTargetException => + if (e.getCause() != null) findCause(e.getCause()) else e + case e: Throwable => + e + } + try { mainMethod.invoke(null, childArgs.toArray) } catch { - case e: InvocationTargetException => e.getCause match { - case cause: Throwable => throw cause - case null => throw e - } + case t: Throwable => + throw findCause(t) } } - private def addJarToClasspath(localJar: String, loader: ExecutorURLClassLoader) { + private def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { val uri = Utils.resolveURI(localJar) uri.getScheme match { case "file" | "local" => @@ -475,11 +639,234 @@ object SparkSubmit { } } +/** Provides utility functions to be used inside SparkSubmit. */ +private[spark] object SparkSubmitUtils { + + // Exposed for testing + private[spark] var printStream = SparkSubmit.printStream + + /** + * Represents a Maven Coordinate + * @param groupId the groupId of the coordinate + * @param artifactId the artifactId of the coordinate + * @param version the version of the coordinate + */ + private[spark] case class MavenCoordinate(groupId: String, artifactId: String, version: String) + +/** + * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided + * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. + * @param coordinates Comma-delimited string of maven coordinates + * @return Sequence of Maven coordinates + */ + private[spark] def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = { + coordinates.split(",").map { p => + val splits = p.replace("/", ":").split(":") + require(splits.length == 3, s"Provided Maven Coordinates must be in the form " + + s"'groupId:artifactId:version'. The coordinate provided is: $p") + require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " + + s"be whitespace. The groupId provided is: ${splits(0)}") + require(splits(1) != null && splits(1).trim.nonEmpty, s"The artifactId cannot be null or " + + s"be whitespace. The artifactId provided is: ${splits(1)}") + require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " + + s"be whitespace. The version provided is: ${splits(2)}") + new MavenCoordinate(splits(0), splits(1), splits(2)) + } + } + + /** + * Extracts maven coordinates from a comma-delimited string + * @param remoteRepos Comma-delimited string of remote repositories + * @return A ChainResolver used by Ivy to search for and resolve dependencies. + */ + private[spark] def createRepoResolvers(remoteRepos: Option[String]): ChainResolver = { + // We need a chain resolver if we want to check multiple repositories + val cr = new ChainResolver + cr.setName("list") + + // the biblio resolver resolves POM declared dependencies + val br: IBiblioResolver = new IBiblioResolver + br.setM2compatible(true) + br.setUsepoms(true) + br.setName("central") + cr.add(br) + + val sp: IBiblioResolver = new IBiblioResolver + sp.setM2compatible(true) + sp.setUsepoms(true) + sp.setRoot("http://dl.bintray.com/spark-packages/maven") + sp.setName("spark-packages") + cr.add(sp) + + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + } + } + cr + } + + /** + * Output a comma-delimited list of paths for the downloaded jars to be added to the classpath + * (will append to jars in SparkSubmit). The name of the jar is given + * after a '!' by Ivy. It also sometimes contains '(bundle)' after '.jar'. Remove that as well. + * @param artifacts Sequence of dependencies that were resolved and retrieved + * @param cacheDirectory directory where jars are cached + * @return a comma-delimited list of paths for the dependencies + */ + private[spark] def resolveDependencyPaths( + artifacts: Array[AnyRef], + cacheDirectory: File): String = { + artifacts.map { artifactInfo => + val artifactString = artifactInfo.toString + val jarName = artifactString.drop(artifactString.lastIndexOf("!") + 1) + cacheDirectory.getAbsolutePath + File.separator + + jarName.substring(0, jarName.lastIndexOf(".jar") + 4) + }.mkString(",") + } + + /** Adds the given maven coordinates to Ivy's module descriptor. */ + private[spark] def addDependenciesToIvy( + md: DefaultModuleDescriptor, + artifacts: Seq[MavenCoordinate], + ivyConfName: String): Unit = { + artifacts.foreach { mvn => + val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) + val dd = new DefaultDependencyDescriptor(ri, false, false) + dd.addDependencyConfiguration(ivyConfName, ivyConfName) + printStream.println(s"${dd.getDependencyId} added as a dependency") + md.addDependency(dd) + } + } + + /** Add exclusion rules for dependencies already included in the spark-assembly */ + private[spark] def addExclusionRules( + ivySettings: IvySettings, + ivyConfName: String, + md: DefaultModuleDescriptor): Unit = { + // Add scala exclusion rule + val scalaArtifacts = new ArtifactId(new ModuleId("*", "scala-library"), "*", "*", "*") + val scalaDependencyExcludeRule = + new DefaultExcludeRule(scalaArtifacts, ivySettings.getMatcher("glob"), null) + scalaDependencyExcludeRule.addConfiguration(ivyConfName) + md.addExcludeRule(scalaDependencyExcludeRule) + + // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka and + // other spark-streaming utility components. Underscore is there to differentiate between + // spark-streaming_2.1x and spark-streaming-kafka-assembly_2.1x + val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", + "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") + + components.foreach { comp => + val sparkArtifacts = + new ArtifactId(new ModuleId("org.apache.spark", s"spark-$comp*"), "*", "*", "*") + val sparkDependencyExcludeRule = + new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null) + sparkDependencyExcludeRule.addConfiguration(ivyConfName) + + md.addExcludeRule(sparkDependencyExcludeRule) + } + } + + /** A nice function to use in tests as well. Values are dummy strings. */ + private[spark] def getModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( + ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) + + /** + * Resolves any dependencies that were supplied through maven coordinates + * @param coordinates Comma-delimited string of maven coordinates + * @param remoteRepos Comma-delimited string of remote repositories other than maven central + * @param ivyPath The path to the local ivy repository + * @return The comma-delimited path to the jars of the given maven artifacts including their + * transitive dependencies + */ + private[spark] def resolveMavenCoordinates( + coordinates: String, + remoteRepos: Option[String], + ivyPath: Option[String], + isTest: Boolean = false): String = { + if (coordinates == null || coordinates.trim.isEmpty) { + "" + } else { + val sysOut = System.out + // To prevent ivy from logging to system out + System.setOut(printStream) + val artifacts = extractMavenCoordinates(coordinates) + // Default configuration name for ivy + val ivyConfName = "default" + // set ivy settings for location of cache + val ivySettings: IvySettings = new IvySettings + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val alternateIvyCache = ivyPath.getOrElse("") + val packagesDirectory: File = + if (alternateIvyCache.trim.isEmpty) { + new File(ivySettings.getDefaultIvyUserDir, "jars") + } else { + ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) + new File(alternateIvyCache, "jars") + } + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the dependency resolvers + val repoResolver = createRepoResolvers(remoteRepos) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) + } else { + resolveOptions.setDownload(true) + } + + // A Module descriptor must be specified. Entries are dummy strings + val md = getModuleDescriptor + md.setDefaultConf(ivyConfName) + + // Add exclusion rules for Spark and Scala Library + addExclusionRules(ivySettings, ivyConfName, md) + // add all supplied maven artifacts as dependencies + addDependenciesToIvy(md, artifacts, ivyConfName) + + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) + } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + "[artifact](-[classifier]).[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + System.setOut(sysOut) + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + } + } +} + /** * Provides an indirection layer for passing arguments as system properties or flags to * the user's driver program or to downstream launcher tools. */ -private[spark] case class OptionAssigner( +private case class OptionAssigner( value: String, clusterManager: Int, deployMode: Int, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 73e921fd83ef2..82e66a374249c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -22,6 +22,7 @@ import java.util.jar.JarFile import scala.collection.mutable.{ArrayBuffer, HashMap} +import org.apache.spark.deploy.SparkSubmitAction._ import org.apache.spark.util.Utils /** @@ -39,8 +40,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var driverExtraClassPath: String = null var driverExtraLibraryPath: String = null var driverExtraJavaOptions: String = null - var driverCores: String = null - var supervise: Boolean = false var queue: String = null var numExecutors: String = null var files: String = null @@ -50,10 +49,22 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var name: String = null var childArgs: ArrayBuffer[String] = new ArrayBuffer[String]() var jars: String = null + var packages: String = null + var repositories: String = null + var ivyRepoPath: String = null var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() + var proxyUser: String = null + + // Standalone cluster mode only + var supervise: Boolean = false + var driverCores: String = null + var submissionToKill: String = null + var submissionToRequestStatusFor: String = null + var useRest: Boolean = true // used internally /** Default properties present in the currently defined defaults file. */ lazy val defaultSparkProperties: HashMap[String, String] = { @@ -79,7 +90,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St // Use `sparkProperties` map along with env vars to fill in any missing parameters loadEnvironmentArguments() - checkRequiredArguments() + validateArguments() /** * Merge values from the default properties file with those specified through --conf. @@ -104,6 +115,15 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St .orElse(sparkProperties.get("spark.master")) .orElse(env.get("MASTER")) .orNull + driverExtraClassPath = Option(driverExtraClassPath) + .orElse(sparkProperties.get("spark.driver.extraClassPath")) + .orNull + driverExtraJavaOptions = Option(driverExtraJavaOptions) + .orElse(sparkProperties.get("spark.driver.extraJavaOptions")) + .orNull + driverExtraLibraryPath = Option(driverExtraLibraryPath) + .orElse(sparkProperties.get("spark.driver.extraLibraryPath")) + .orNull driverMemory = Option(driverMemory) .orElse(sparkProperties.get("spark.driver.memory")) .orElse(env.get("SPARK_DRIVER_MEMORY")) @@ -123,6 +143,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St .orNull name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull + ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) @@ -162,10 +183,21 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St if (name == null && primaryResource != null) { name = Utils.stripDirectory(primaryResource) } + + // Action should be SUBMIT unless otherwise specified + action = Option(action).getOrElse(SUBMIT) } /** Ensure that required fields exists. Call this only once all defaults are loaded. */ - private def checkRequiredArguments(): Unit = { + private def validateArguments(): Unit = { + action match { + case SUBMIT => validateSubmitArguments() + case KILL => validateKillArguments() + case REQUEST_STATUS => validateStatusRequestArguments() + } + } + + private def validateSubmitArguments(): Unit = { if (args.length == 0) { printUsageAndExit(-1) } @@ -188,6 +220,29 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } + private def validateKillArguments(): Unit = { + if (!master.startsWith("spark://")) { + SparkSubmit.printErrorAndExit("Killing submissions is only supported in standalone mode!") + } + if (submissionToKill == null) { + SparkSubmit.printErrorAndExit("Please specify a submission to kill.") + } + } + + private def validateStatusRequestArguments(): Unit = { + if (!master.startsWith("spark://")) { + SparkSubmit.printErrorAndExit( + "Requesting submission statuses is only supported in standalone mode!") + } + if (submissionToRequestStatusFor == null) { + SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") + } + } + + def isStandaloneCluster: Boolean = { + master.startsWith("spark://") && deployMode == "cluster" + } + override def toString = { s"""Parsed arguments: | master $master @@ -212,6 +267,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | name $name | childArgs [${childArgs.mkString(" ")}] | jars $jars + | packages $packages + | repositories $repositories | verbose $verbose | |Spark properties used, including those specified through @@ -294,6 +351,22 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St propertiesFile = value parse(tail) + case ("--kill") :: value :: tail => + submissionToKill = value + if (action != null) { + SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.") + } + action = KILL + parse(tail) + + case ("--status") :: value :: tail => + submissionToRequestStatusFor = value + if (action != null) { + SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.") + } + action = REQUEST_STATUS + parse(tail) + case ("--supervise") :: tail => supervise = true parse(tail) @@ -318,6 +391,14 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St jars = Utils.resolveURIs(value) parse(tail) + case ("--packages") :: value :: tail => + packages = value + parse(tail) + + case ("--repositories") :: value :: tail => + repositories = value + parse(tail) + case ("--conf" | "-c") :: value :: tail => value.split("=", 2).toSeq match { case Seq(k, v) => sparkProperties(k) = v @@ -325,6 +406,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } parse(tail) + case ("--proxy-user") :: value :: tail => + proxyUser = value + parse(tail) + case ("--help" | "-h") :: tail => printUsageAndExit(0) @@ -332,6 +417,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St verbose = true parse(tail) + case ("--version") :: tail => + SparkSubmit.printVersionAndExit() + case EQ_SEPARATED_OPT(opt, value) :: tail => parse(opt :: value :: tail) @@ -358,7 +446,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St outStream.println("Unknown/unsupported param " + unknownParam) } outStream.println( - """Usage: spark-submit [options] [app options] + """Usage: spark-submit [options] [app arguments] + |Usage: spark-submit --kill [submission ID] --master [spark://...] + |Usage: spark-submit --status [submission ID] --master [spark://...] + | |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -368,6 +459,13 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | --name NAME A name of your application. | --jars JARS Comma-separated list of local jars to include on the driver | and executor classpaths. + | --packages Comma-separated list of maven coordinates of jars to include + | on the driver and executor classpaths. Will search the local + | maven repo, then maven central and any additional remote + | repositories given by --repositories. The format for the + | coordinates should be groupId:artifactId:version. + | --repositories Comma-separated list of additional remote repositories to + | search for the maven coordinates given with --packages. | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place | on the PYTHONPATH for Python apps. | --files FILES Comma-separated list of files to be placed in the working @@ -386,12 +484,17 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G). | + | --proxy-user NAME User to impersonate when submitting the application. + | | --help, -h Show this help message and exit | --verbose, -v Print additional debug output + | --version, Print the version of current Spark | | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). | --supervise If given, restarts the driver on failure. + | --kill SUBMISSION_ID If given, kills the driver specified. + | --status SUBMISSION_ID If given, requests the status of the driver specified. | | Spark standalone and Mesos only: | --total-executor-cores NUM Total cores for all executors. diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index 2eab9981845e8..311048cdaa324 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -17,8 +17,6 @@ package org.apache.spark.deploy -import java.io.File - import scala.collection.JavaConversions._ import org.apache.spark.util.{RedirectThread, Utils} @@ -164,6 +162,8 @@ private[spark] object SparkSubmitDriverBootstrapper { } } val returnCode = process.waitFor() + stdoutThread.join() + stderrThread.join() sys.exit(returnCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 0ae45f4ad9130..554ce133e14fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -68,8 +68,8 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis // Constants used to parse Spark 1.0.0 log directories. private[history] val LOG_PREFIX = "EVENT_LOG_" - private[history] val SPARK_VERSION_PREFIX = "SPARK_VERSION_" - private[history] val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_" + private[history] val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" + private[history] val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" /** @@ -104,7 +104,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!fs.exists(path)) { var msg = s"Log directory specified does not exist: $logDir." if (logDir == DEFAULT_LOG_DIR) { - msg += " Did you configure the correct one through spark.fs.history.logDirectory?" + msg += " Did you configure the correct one through spark.history.fs.logDirectory?" } throw new IllegalArgumentException(msg) } @@ -173,9 +173,10 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val logInfos = statusList .filter { entry => try { - val modTime = getModificationTime(entry) - newLastModifiedTime = math.max(newLastModifiedTime, modTime) - modTime >= lastModifiedTime + getModificationTime(entry).map { time => + newLastModifiedTime = math.max(newLastModifiedTime, time) + time >= lastModifiedTime + }.getOrElse(false) } catch { case e: AccessControlException => // Do not use "logInfo" since these messages can get pretty noisy if printed on @@ -193,7 +194,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis None } } - .sortBy { info => (-info.endTime, -info.startTime) } + .sortWith(compareAppInfo) lastModifiedTime = newLastModifiedTime @@ -213,7 +214,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val newIterator = logInfos.iterator.buffered val oldIterator = applications.values.iterator.buffered while (newIterator.hasNext && oldIterator.hasNext) { - if (newIterator.head.endTime > oldIterator.head.endTime) { + if (compareAppInfo(newIterator.head, oldIterator.head)) { addIfAbsent(newIterator.next) } else { addIfAbsent(oldIterator.next) @@ -229,13 +230,25 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } } + /** + * Comparison function that defines the sort order for the application listing. + * + * @return Whether `i1` should precede `i2`. + */ + private def compareAppInfo( + i1: FsApplicationHistoryInfo, + i2: FsApplicationHistoryInfo): Boolean = { + if (i1.endTime != i2.endTime) i1.endTime >= i2.endTime else i1.startTime >= i2.startTime + } + /** * Replays the events in the specified log file and returns information about the associated * application. */ private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationHistoryInfo = { val logPath = eventLog.getPath() - val (logInput, sparkVersion) = + logInfo(s"Replaying log path: $logPath") + val logInput = if (isLegacyLogDirectory(eventLog)) { openLegacyEventLog(logPath) } else { @@ -244,14 +257,14 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis try { val appListener = new ApplicationEventListener bus.addListener(appListener) - bus.replay(logInput, sparkVersion) + bus.replay(logInput, logPath.toString) new FsApplicationHistoryInfo( logPath.getName(), appListener.appId.getOrElse(logPath.getName()), appListener.appName.getOrElse(NOT_STARTED), appListener.startTime.getOrElse(-1L), appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog), + getModificationTime(eventLog).get, appListener.sparkUser.getOrElse(NOT_STARTED), isApplicationCompleted(eventLog)) } finally { @@ -264,30 +277,24 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis * log file (along with other metadata files), which is the case for directories generated by * the code in previous releases. * - * @return 2-tuple of (input stream of the events, version of Spark which wrote the log) + * @return input stream that holds one JSON record per line. */ - private[history] def openLegacyEventLog(dir: Path): (InputStream, String) = { + private[history] def openLegacyEventLog(dir: Path): InputStream = { val children = fs.listStatus(dir) var eventLogPath: Path = null var codecName: Option[String] = None - var sparkVersion: String = null children.foreach { child => child.getPath().getName() match { case name if name.startsWith(LOG_PREFIX) => eventLogPath = child.getPath() - case codec if codec.startsWith(COMPRESSION_CODEC_PREFIX) => codecName = Some(codec.substring(COMPRESSION_CODEC_PREFIX.length())) - - case version if version.startsWith(SPARK_VERSION_PREFIX) => - sparkVersion = version.substring(SPARK_VERSION_PREFIX.length()) - case _ => } } - if (eventLogPath == null || sparkVersion == null) { + if (eventLogPath == null) { throw new IllegalArgumentException(s"$dir is not a Spark application log directory.") } @@ -299,7 +306,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } val in = new BufferedInputStream(fs.open(eventLogPath)) - (codec.map(_.compressedInputStream(in)).getOrElse(in), sparkVersion) + codec.map(_.compressedInputStream(in)).getOrElse(in) } /** @@ -310,11 +317,16 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDir() - private def getModificationTime(fsEntry: FileStatus): Long = { - if (fsEntry.isDir) { - fs.listStatus(fsEntry.getPath).map(_.getModificationTime()).max + /** + * Returns the modification time of the given event log. If the status points at an empty + * directory, `None` is returned, indicating that there isn't an event log at that location. + */ + private def getModificationTime(fsEntry: FileStatus): Option[Long] = { + if (isLegacyLogDirectory(fsEntry)) { + val statusList = fs.listStatus(fsEntry.getPath) + if (!statusList.isEmpty) Some(statusList.map(_.getModificationTime()).max) else None } else { - fsEntry.getModificationTime() + Some(fsEntry.getModificationTime()) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index e4e7bc2216014..26ebc75971c66 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -61,9 +61,10 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { // page, `...` will be displayed. if (allApps.size > 0) { val leftSideIndices = - rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _) + rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _, requestedIncomplete) val rightSideIndices = - rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount) + rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount, + requestedIncomplete)

Showing {actualFirst + 1}-{last + 1} of {allApps.size} @@ -122,8 +123,10 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { "Spark User", "Last Updated") - private def rangeIndices(range: Seq[Int], condition: Int => Boolean): Seq[Node] = { - range.filter(condition).map(nextPage => {nextPage} ) + private def rangeIndices(range: Seq[Int], condition: Int => Boolean, showIncomplete: Boolean): + Seq[Node] = { + range.filter(condition).map(nextPage => + {nextPage} ) } private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 5eeb9fe526248..259888f3c33d0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -43,6 +43,7 @@ import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI +import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI @@ -52,12 +53,12 @@ private[spark] class Master( host: String, port: Int, webUiPort: Int, - val securityMgr: SecurityManager) + val securityMgr: SecurityManager, + val conf: SparkConf) extends Actor with ActorLogReceive with Logging with LeaderElectable { import context.dispatcher // to use Akka's scheduler.schedule() - val conf = new SparkConf val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -95,7 +96,7 @@ private[spark] class Master( val webUi = new MasterWebUI(this, webUiPort) val masterPublicAddress = { - val envVar = System.getenv("SPARK_PUBLIC_DNS") + val envVar = conf.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host } @@ -121,6 +122,17 @@ private[spark] class Master( throw new SparkException("spark.deploy.defaultCores must be positive") } + // Alternative application submission gateway that is stable across Spark versions + private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) + private val restServer = + if (restServerEnabled) { + val port = conf.getInt("spark.master.rest.port", 6066) + Some(new StandaloneRestServer(host, port, self, masterUrl, conf)) + } else { + None + } + private val restServerBoundPort = restServer.map(_.start()) + override def preStart() { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") @@ -174,6 +186,7 @@ private[spark] class Master( recoveryCompletionTask.cancel() } webUi.stop() + restServer.foreach(_.stop()) masterMetricsSystem.stop() applicationMetricsSystem.stop() persistenceEngine.close() @@ -421,7 +434,9 @@ private[spark] class Master( } case RequestMasterState => { - sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, + sender ! MasterStateResponse( + host, port, restServerBoundPort, + workers.toArray, apps.toArray, completedApps.toArray, drivers.toArray, completedDrivers.toArray, state) } @@ -429,8 +444,8 @@ private[spark] class Master( timeOutDeadWorkers() } - case RequestWebUIPort => { - sender ! WebUIPortResponse(webUi.boundPort) + case BoundPortsRequest => { + sender ! BoundPortsResponse(port, webUi.boundPort, restServerBoundPort) } } @@ -656,7 +671,7 @@ private[spark] class Master( def registerApplication(app: ApplicationInfo): Unit = { val appAddress = app.driver.path.address - if (addressToWorker.contains(appAddress)) { + if (addressToApp.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return } @@ -722,13 +737,13 @@ private[spark] class Master( val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found" try { val eventLogFile = app.desc.eventLogDir - .map { dir => EventLoggingListener.getLogPath(dir, app.id) } + .map { dir => EventLoggingListener.getLogPath(dir, app.id, app.desc.eventLogCodec) } .getOrElse { // Event logging is not enabled for this application app.desc.appUiUrl = notFoundBasePath return false } - + val fs = Utils.getHadoopFileSystem(eventLogFile, hadoopConf) if (fs.exists(new Path(eventLogFile + EventLoggingListener.IN_PROGRESS))) { @@ -741,12 +756,12 @@ private[spark] class Master( return false } - val (logInput, sparkVersion) = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) + val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}") try { - replayBus.replay(logInput, sparkVersion) + replayBus.replay(logInput, eventLogFile) } finally { logInput.close() } @@ -851,7 +866,7 @@ private[spark] object Master extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) + val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) actorSystem.awaitTermination() } @@ -875,19 +890,26 @@ private[spark] object Master extends Logging { Address(protocol, systemName, host, port) } + /** + * Start the Master and return a four tuple of: + * (1) The Master actor system + * (2) The bound port + * (3) The web UI bound port + * (4) The REST server bound port, if any + */ def startSystemAndActor( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int) = { + conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = { val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, securityManager = securityMgr) - val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort, - securityMgr), actorName) + val actor = actorSystem.actorOf( + Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) val timeout = AkkaUtils.askTimeout(conf) - val respFuture = actor.ask(RequestWebUIPort)(timeout) - val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] - (actorSystem, boundPort, resp.webUIBoundPort) + val portsRequest = actor.ask(BoundPortsRequest)(timeout) + val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] + (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index db72d8ae9bdaf..15c6296888f70 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -36,7 +36,7 @@ private[master] object MasterMessages { case object CompleteRecovery - case object RequestWebUIPort + case object BoundPortsRequest - case class WebUIPortResponse(webUIBoundPort: Int) + case class BoundPortsResponse(actorPort: Int, webUIPort: Int, restPort: Option[Int]) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 7ca3b08a28728..fd514f07664a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -46,19 +46,19 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] val state = Await.result(stateFuture, timeout) - val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory") + val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User", - "State", "Duration") + val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", + "User", "State", "Duration") val activeApps = state.activeApps.sortBy(_.startTime).reverse val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) val completedApps = state.completedApps.sortBy(_.endTime).reverse val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) - val driverHeaders = Seq("ID", "Submitted Time", "Worker", "State", "Cores", "Memory", - "Main Class") + val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores", + "Memory", "Main Class") val activeDrivers = state.activeDrivers.sortBy(_.startTime).reverse val activeDriversTable = UIUtils.listingTable(driverHeaders, driverRow, activeDrivers) val completedDrivers = state.completedDrivers.sortBy(_.startTime).reverse @@ -73,6 +73,14 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
  • URL: {state.uri}
  • + { + state.restUri.map { uri => +
  • + REST URL: {uri} + (cluster mode) +
  • + }.getOrElse { Seq.empty } + }
  • Workers: {state.workers.size}
  • Cores: {state.workers.map(_.cores).sum} Total, {state.workers.map(_.coresUsed).sum} Used
  • @@ -188,7 +196,7 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {Utils.megabytesToString(driver.desc.mem.toLong)} - {driver.desc.command.arguments(1)} + {driver.desc.command.arguments(2)} } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala new file mode 100644 index 0000000000000..c4be1f19e8e9f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.{DataOutputStream, FileNotFoundException} +import java.net.{HttpURLConnection, SocketException, URL} +import javax.servlet.http.HttpServletResponse + +import scala.io.Source + +import com.fasterxml.jackson.core.JsonProcessingException +import com.google.common.base.Charsets + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} + +/** + * A client that submits applications to the standalone Master using a REST protocol. + * This client is intended to communicate with the [[StandaloneRestServer]] and is + * currently used for cluster mode only. + * + * In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action], + * where [action] can be one of create, kill, or status. Each type of request is represented in + * an HTTP message sent to the following prefixes: + * (1) submit - POST to /submissions/create + * (2) kill - POST /submissions/kill/[submissionId] + * (3) status - GET /submissions/status/[submissionId] + * + * In the case of (1), parameters are posted in the HTTP body in the form of JSON fields. + * Otherwise, the URL fully specifies the intended action of the client. + * + * Since the protocol is expected to be stable across Spark versions, existing fields cannot be + * added or removed, though new optional fields can be added. In the rare event that forward or + * backward compatibility is broken, Spark must introduce a new protocol version (e.g. v2). + * + * The client and the server must communicate using the same version of the protocol. If there + * is a mismatch, the server will respond with the highest protocol version it supports. A future + * implementation of this client can use that information to retry using the version specified + * by the server. + */ +private[spark] class StandaloneRestClient extends Logging { + import StandaloneRestClient._ + + /** + * Submit an application specified by the parameters in the provided request. + * + * If the submission was successful, poll the status of the submission and report + * it to the user. Otherwise, report the error message provided by the server. + */ + def createSubmission( + master: String, + request: CreateSubmissionRequest): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to launch an application in $master.") + validateMaster(master) + val url = getSubmitUrl(master) + val response = postJson(url, request.toJson) + response match { + case s: CreateSubmissionResponse => + reportSubmissionStatus(master, s) + handleRestResponse(s) + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + response + } + + /** Request that the server kill the specified submission. */ + def killSubmission(master: String, submissionId: String): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to kill submission $submissionId in $master.") + validateMaster(master) + val response = post(getKillUrl(master, submissionId)) + response match { + case k: KillSubmissionResponse => handleRestResponse(k) + case unexpected => handleUnexpectedRestResponse(unexpected) + } + response + } + + /** Request the status of a submission from the server. */ + def requestSubmissionStatus( + master: String, + submissionId: String, + quiet: Boolean = false): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request for the status of submission $submissionId in $master.") + validateMaster(master) + val response = get(getStatusUrl(master, submissionId)) + response match { + case s: SubmissionStatusResponse => if (!quiet) { handleRestResponse(s) } + case unexpected => handleUnexpectedRestResponse(unexpected) + } + response + } + + /** Construct a message that captures the specified parameters for submitting an application. */ + def constructSubmitRequest( + appResource: String, + mainClass: String, + appArgs: Array[String], + sparkProperties: Map[String, String], + environmentVariables: Map[String, String]): CreateSubmissionRequest = { + val message = new CreateSubmissionRequest + message.clientSparkVersion = sparkVersion + message.appResource = appResource + message.mainClass = mainClass + message.appArgs = appArgs + message.sparkProperties = sparkProperties + message.environmentVariables = environmentVariables + message.validate() + message + } + + /** Send a GET request to the specified URL. */ + private def get(url: URL): SubmitRestProtocolResponse = { + logDebug(s"Sending GET request to server at $url.") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("GET") + readResponse(conn) + } + + /** Send a POST request to the specified URL. */ + private def post(url: URL): SubmitRestProtocolResponse = { + logDebug(s"Sending POST request to server at $url.") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + readResponse(conn) + } + + /** Send a POST request with the given JSON as the body to the specified URL. */ + private def postJson(url: URL, json: String): SubmitRestProtocolResponse = { + logDebug(s"Sending POST request to server at $url:\n$json") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + conn.setRequestProperty("Content-Type", "application/json") + conn.setRequestProperty("charset", "utf-8") + conn.setDoOutput(true) + val out = new DataOutputStream(conn.getOutputStream) + out.write(json.getBytes(Charsets.UTF_8)) + out.close() + readResponse(conn) + } + + /** + * Read the response from the server and return it as a validated [[SubmitRestProtocolResponse]]. + * If the response represents an error, report the embedded message to the user. + * Exposed for testing. + */ + private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { + try { + val dataStream = + if (connection.getResponseCode == HttpServletResponse.SC_OK) { + connection.getInputStream + } else { + connection.getErrorStream + } + // If the server threw an exception while writing a response, it will not have a body + if (dataStream == null) { + throw new SubmitRestProtocolException("Server returned empty body") + } + val responseJson = Source.fromInputStream(dataStream).mkString + logDebug(s"Response from the server:\n$responseJson") + val response = SubmitRestProtocolMessage.fromJson(responseJson) + response.validate() + response match { + // If the response is an error, log the message + case error: ErrorResponse => + logError(s"Server responded with error:\n${error.message}") + error + // Otherwise, simply return the response + case response: SubmitRestProtocolResponse => response + case unexpected => + throw new SubmitRestProtocolException( + s"Message received from server was not a response:\n${unexpected.toJson}") + } + } catch { + case unreachable @ (_: FileNotFoundException | _: SocketException) => + throw new SubmitRestConnectionException( + s"Unable to connect to server ${connection.getURL}", unreachable) + case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + throw new SubmitRestProtocolException( + "Malformed response received from server", malformed) + } + } + + /** Return the REST URL for creating a new submission. */ + private def getSubmitUrl(master: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/create") + } + + /** Return the REST URL for killing an existing submission. */ + private def getKillUrl(master: String, submissionId: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/kill/$submissionId") + } + + /** Return the REST URL for requesting the status of an existing submission. */ + private def getStatusUrl(master: String, submissionId: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/status/$submissionId") + } + + /** Return the base URL for communicating with the server, including the protocol version. */ + private def getBaseUrl(master: String): String = { + val masterUrl = master.stripPrefix("spark://").stripSuffix("/") + s"http://$masterUrl/$PROTOCOL_VERSION/submissions" + } + + /** Throw an exception if this is not standalone mode. */ + private def validateMaster(master: String): Unit = { + if (!master.startsWith("spark://")) { + throw new IllegalArgumentException("This REST client is only supported in standalone mode.") + } + } + + /** Report the status of a newly created submission. */ + private def reportSubmissionStatus( + master: String, + submitResponse: CreateSubmissionResponse): Unit = { + if (submitResponse.success) { + val submissionId = submitResponse.submissionId + if (submissionId != null) { + logInfo(s"Submission successfully created as $submissionId. Polling submission state...") + pollSubmissionStatus(master, submissionId) + } else { + // should never happen + logError("Application successfully submitted, but submission ID was not provided!") + } + } else { + val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("") + logError("Application submission failed" + failMessage) + } + } + + /** + * Poll the status of the specified submission and log it. + * This retries up to a fixed number of times before giving up. + */ + private def pollSubmissionStatus(master: String, submissionId: String): Unit = { + (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => + val response = requestSubmissionStatus(master, submissionId, quiet = true) + val statusResponse = response match { + case s: SubmissionStatusResponse => s + case _ => return // unexpected type, let upstream caller handle it + } + if (statusResponse.success) { + val driverState = Option(statusResponse.driverState) + val workerId = Option(statusResponse.workerId) + val workerHostPort = Option(statusResponse.workerHostPort) + val exception = Option(statusResponse.message) + // Log driver state, if present + driverState match { + case Some(state) => logInfo(s"State of driver $submissionId is now $state.") + case _ => logError(s"State of driver $submissionId was not found!") + } + // Log worker node, if present + (workerId, workerHostPort) match { + case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.") + case _ => + } + // Log exception stack trace, if present + exception.foreach { e => logError(e) } + return + } + Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL) + } + logError(s"Error: Master did not recognize driver $submissionId.") + } + + /** Log the response sent by the server in the REST application submission protocol. */ + private def handleRestResponse(response: SubmitRestProtocolResponse): Unit = { + logInfo(s"Server responded with ${response.messageType}:\n${response.toJson}") + } + + /** Log an appropriate error if the response sent by the server is not of the expected type. */ + private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = { + logError(s"Error: Server responded with message of unexpected type ${unexpected.messageType}.") + } +} + +private[spark] object StandaloneRestClient { + val REPORT_DRIVER_STATUS_INTERVAL = 1000 + val REPORT_DRIVER_STATUS_MAX_TRIES = 10 + val PROTOCOL_VERSION = "v1" + + /** + * Submit an application, assuming Spark parameters are specified through the given config. + * This is abstracted to its own method for testing purposes. + */ + private[rest] def run( + appResource: String, + mainClass: String, + appArgs: Array[String], + conf: SparkConf, + env: Map[String, String] = sys.env): SubmitRestProtocolResponse = { + val master = conf.getOption("spark.master").getOrElse { + throw new IllegalArgumentException("'spark.master' must be set.") + } + val sparkProperties = conf.getAll.toMap + val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } + val client = new StandaloneRestClient + val submitRequest = client.constructSubmitRequest( + appResource, mainClass, appArgs, sparkProperties, environmentVariables) + client.createSubmission(master, submitRequest) + } + + def main(args: Array[String]): Unit = { + if (args.size < 2) { + sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]") + sys.exit(1) + } + val appResource = args(0) + val mainClass = args(1) + val appArgs = args.slice(2, args.size) + val conf = new SparkConf + run(appResource, mainClass, appArgs, conf) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala new file mode 100644 index 0000000000000..f9e0478e4f874 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.File +import java.net.InetSocketAddress +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} + +import scala.io.Source + +import akka.actor.ActorRef +import com.fasterxml.jackson.core.JsonProcessingException +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.deploy.ClientArguments._ + +/** + * A server that responds to requests submitted by the [[StandaloneRestClient]]. + * This is intended to be embedded in the standalone Master and used in cluster mode only. + * + * This server responds with different HTTP codes depending on the situation: + * 200 OK - Request was processed successfully + * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type + * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand + * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request + * + * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]] + * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]] + * instead of the one expected by the client. If the construction of this error response itself + * fails, the response will consist of an empty body with a response code that indicates internal + * server error. + * + * @param host the address this server should bind to + * @param requestedPort the port this server will attempt to bind to + * @param masterActor reference to the Master actor to which requests can be sent + * @param masterUrl the URL of the Master new drivers will attempt to connect to + * @param masterConf the conf used by the Master + */ +private[spark] class StandaloneRestServer( + host: String, + requestedPort: Int, + masterActor: ActorRef, + masterUrl: String, + masterConf: SparkConf) + extends Logging { + + import StandaloneRestServer._ + + private var _server: Option[Server] = None + + // A mapping from URL prefixes to servlets that serve them. Exposed for testing. + protected val baseContext = s"/$PROTOCOL_VERSION/submissions" + protected val contextToServlet = Map[String, StandaloneRestServlet]( + s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf), + s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf), + s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf), + "/*" -> new ErrorServlet // default handler + ) + + /** Start the server and return the bound port. */ + def start(): Int = { + val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) + _server = Some(server) + logInfo(s"Started REST server for submitting applications on port $boundPort") + boundPort + } + + /** + * Map the servlets to their corresponding contexts and attach them to a server. + * Return a 2-tuple of the started server and the bound port. + */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server(new InetSocketAddress(host, startPort)) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val mainHandler = new ServletContextHandler + mainHandler.setContextPath("/") + contextToServlet.foreach { case (prefix, servlet) => + mainHandler.addServlet(new ServletHolder(servlet), prefix) + } + server.setHandler(mainHandler) + server.start() + val boundPort = server.getConnectors()(0).getLocalPort + (server, boundPort) + } + + def stop(): Unit = { + _server.foreach(_.stop()) + } +} + +private[rest] object StandaloneRestServer { + val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION + val SC_UNKNOWN_PROTOCOL_VERSION = 468 +} + +/** + * An abstract servlet for handling requests passed to the [[StandaloneRestServer]]. + */ +private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging { + + /** + * Serialize the given response message to JSON and send it through the response servlet. + * This validates the response before sending it to ensure it is properly constructed. + */ + protected def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val message = validateResponse(responseMessage, responseServlet) + responseServlet.setContentType("application/json") + responseServlet.setCharacterEncoding("utf-8") + responseServlet.getWriter.write(message.toJson) + } + + /** + * Return any fields in the client request message that the server does not know about. + * + * The mechanism for this is to reconstruct the JSON on the server side and compare the + * diff between this JSON and the one generated on the client side. Any fields that are + * only in the client JSON are treated as unexpected. + */ + protected def findUnknownFields( + requestJson: String, + requestMessage: SubmitRestProtocolMessage): Array[String] = { + val clientSideJson = parse(requestJson) + val serverSideJson = parse(requestMessage.toJson) + val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) + unknown match { + case j: JObject => j.obj.map { case (k, _) => k }.toArray + case _ => Array.empty[String] // No difference + } + } + + /** Return a human readable String representation of the exception. */ + protected def formatException(e: Throwable): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } + + /** Construct an error message to signal the fact that an exception has been thrown. */ + protected def handleError(message: String): ErrorResponse = { + val e = new ErrorResponse + e.serverSparkVersion = sparkVersion + e.message = message + e + } + + /** + * Parse a submission ID from the relative path, assuming it is the first part of the path. + * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. + * The returned submission ID cannot be empty. If the path is unexpected, return None. + */ + protected def parseSubmissionId(path: String): Option[String] = { + if (path == null || path.isEmpty) { + None + } else { + path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) + } + } + + /** + * Validate the response to ensure that it is correctly constructed. + * + * If it is, simply return the message as is. Otherwise, return an error response instead + * to propagate the exception back to the client and set the appropriate error code. + */ + private def validateResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + try { + responseMessage.validate() + responseMessage + } catch { + case e: Exception => + responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + handleError("Internal server error: " + formatException(e)) + } + } +} + +/** + * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends StandaloneRestServlet { + + /** + * If a submission ID is specified in the URL, have the Master kill the corresponding + * driver and return an appropriate response to the client. Otherwise, return error. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleKill).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in kill request.") + } + sendResponse(responseMessage, response) + } + + protected def handleKill(submissionId: String): KillSubmissionResponse = { + val askTimeout = AkkaUtils.askTimeout(conf) + val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) + val k = new KillSubmissionResponse + k.serverSparkVersion = sparkVersion + k.message = response.message + k.submissionId = submissionId + k.success = response.success + k + } +} + +/** + * A servlet for handling status requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends StandaloneRestServlet { + + /** + * If a submission ID is specified in the URL, request the status of the corresponding + * driver from the Master and include it in the response. Otherwise, return error. + */ + protected override def doGet( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleStatus).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in status request.") + } + sendResponse(responseMessage, response) + } + + protected def handleStatus(submissionId: String): SubmissionStatusResponse = { + val askTimeout = AkkaUtils.askTimeout(conf) + val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) + val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } + val d = new SubmissionStatusResponse + d.serverSparkVersion = sparkVersion + d.submissionId = submissionId + d.success = response.found + d.driverState = response.state.map(_.toString).orNull + d.workerId = response.workerId.orNull + d.workerHostPort = response.workerHostPort.orNull + d.message = message.orNull + d + } +} + +/** + * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class SubmitRequestServlet( + masterActor: ActorRef, + masterUrl: String, + conf: SparkConf) + extends StandaloneRestServlet { + + /** + * Submit an application to the Master with parameters specified in the request. + * + * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. + * If the request is successfully processed, return an appropriate response to the + * client indicating so. Otherwise, return error instead. + */ + protected override def doPost( + requestServlet: HttpServletRequest, + responseServlet: HttpServletResponse): Unit = { + val responseMessage = + try { + val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + // The response should have already been validated on the client. + // In case this is not true, validate it ourselves to avoid potential NPEs. + requestMessage.validate() + handleSubmit(requestMessageJson, requestMessage, responseServlet) + } catch { + // The client failed to provide a valid JSON, so this is not our fault + case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Malformed request: " + formatException(e)) + } + sendResponse(responseMessage, responseServlet) + } + + /** + * Handle the submit request and construct an appropriate response to return to the client. + * + * This assumes that the request message is already successfully validated. + * If the request message is not of the expected type, return error to the client. + */ + private def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val askTimeout = AkkaUtils.askTimeout(conf) + val driverDescription = buildDriverDescription(submitRequest) + val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val submitResponse = new CreateSubmissionResponse + submitResponse.serverSparkVersion = sparkVersion + submitResponse.message = response.message + submitResponse.success = response.success + submitResponse.submissionId = response.driverId.orNull + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + submitResponse.unknownFields = unknownFields + } + submitResponse + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") + } + } + + /** + * Build a driver description from the fields specified in the submit request. + * + * This involves constructing a command that takes into account memory, java options, + * classpath and other settings to launch the driver. This does not currently consider + * fields used by python applications since python is not supported in standalone + * cluster mode yet. + */ + private def buildDriverDescription(request: CreateSubmissionRequest): DriverDescription = { + // Required fields, including the main class because python is not yet supported + val appResource = Option(request.appResource).getOrElse { + throw new SubmitRestMissingFieldException("Application jar is missing.") + } + val mainClass = Option(request.mainClass).getOrElse { + throw new SubmitRestMissingFieldException("Main class is missing.") + } + + // Optional fields + val sparkProperties = request.sparkProperties + val driverMemory = sparkProperties.get("spark.driver.memory") + val driverCores = sparkProperties.get("spark.driver.cores") + val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") + val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") + val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") + val superviseDriver = sparkProperties.get("spark.driver.supervise") + val appArgs = request.appArgs + val environmentVariables = request.environmentVariables + + // Construct driver description + val conf = new SparkConf(false) + .setAll(sparkProperties) + .set("spark.master", masterUrl) + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + "org.apache.spark.deploy.worker.DriverWrapper", + Seq("{{WORKER_URL}}", "{{USER_JAR}}", mainClass) ++ appArgs, // args to the DriverWrapper + environmentVariables, extraClassPath, extraLibraryPath, javaOpts) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + new DriverDescription( + appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) + } +} + +/** + * A default servlet that handles error cases that are not captured by other servlets. + */ +private class ErrorServlet extends StandaloneRestServlet { + private val serverVersion = StandaloneRestServer.PROTOCOL_VERSION + + /** Service a faulty request by returning an appropriate error message to the client. */ + protected override def service( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val path = request.getPathInfo + val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList + var versionMismatch = false + var msg = + parts match { + case Nil => + // http://host:port/ + "Missing protocol version." + case `serverVersion` :: Nil => + // http://host:port/correct-version + "Missing the /submissions prefix." + case `serverVersion` :: "submissions" :: tail => + // http://host:port/correct-version/submissions/* + "Missing an action: please specify one of /create, /kill, or /status." + case unknownVersion :: tail => + // http://host:port/unknown-version/* + versionMismatch = true + s"Unknown protocol version '$unknownVersion'." + case _ => + // never reached + s"Malformed path $path." + } + msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." + val error = handleError(msg) + // If there is a version mismatch, include the highest protocol version that + // this server supports in case the client wants to retry with our version + if (versionMismatch) { + error.highestProtocolVersion = serverVersion + response.setStatus(StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + } else { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + } + sendResponse(error, response) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala new file mode 100644 index 0000000000000..d7a0bdbe10778 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * An exception thrown in the REST application submission protocol. + */ +private[spark] class SubmitRestProtocolException(message: String, cause: Throwable = null) + extends Exception(message, cause) + +/** + * An exception thrown if a field is missing from a [[SubmitRestProtocolMessage]]. + */ +private[spark] class SubmitRestMissingFieldException(message: String) + extends SubmitRestProtocolException(message) + +/** + * An exception thrown if the REST client cannot reach the REST server. + */ +private[spark] class SubmitRestConnectionException(message: String, cause: Throwable) + extends SubmitRestProtocolException(message, cause) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala new file mode 100644 index 0000000000000..8f36635674a28 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import com.fasterxml.jackson.annotation._ +import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper, SerializationFeature} +import com.fasterxml.jackson.module.scala.DefaultScalaModule +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.util.Utils + +/** + * An abstract message exchanged in the REST application submission protocol. + * + * This message is intended to be serialized to and deserialized from JSON in the exchange. + * Each message can either be a request or a response and consists of three common fields: + * (1) the action, which fully specifies the type of the message + * (2) the Spark version of the client / server + * (3) an optional message + */ +@JsonInclude(Include.NON_NULL) +@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) +@JsonPropertyOrder(alphabetic = true) +private[spark] abstract class SubmitRestProtocolMessage { + @JsonIgnore + val messageType = Utils.getFormattedClassName(this) + + val action: String = messageType + var message: String = null + + // For JSON deserialization + private def setAction(a: String): Unit = { } + + /** + * Serialize the message to JSON. + * This also ensures that the message is valid and its fields are in the expected format. + */ + def toJson: String = { + validate() + SubmitRestProtocolMessage.mapper.writeValueAsString(this) + } + + /** + * Assert the validity of the message. + * If the validation fails, throw a [[SubmitRestProtocolException]]. + */ + final def validate(): Unit = { + try { + doValidate() + } catch { + case e: Exception => + throw new SubmitRestProtocolException(s"Validation of message $messageType failed!", e) + } + } + + /** Assert the validity of the message */ + protected def doValidate(): Unit = { + if (action == null) { + throw new SubmitRestMissingFieldException(s"The action field is missing in $messageType") + } + } + + /** Assert that the specified field is set in this message. */ + protected def assertFieldIsSet[T](value: T, name: String): Unit = { + if (value == null) { + throw new SubmitRestMissingFieldException(s"'$name' is missing in message $messageType.") + } + } + + /** + * Assert a condition when validating this message. + * If the assertion fails, throw a [[SubmitRestProtocolException]]. + */ + protected def assert(condition: Boolean, failMessage: String): Unit = { + if (!condition) { throw new SubmitRestProtocolException(failMessage) } + } +} + +/** + * Helper methods to process serialized [[SubmitRestProtocolMessage]]s. + */ +private[spark] object SubmitRestProtocolMessage { + private val packagePrefix = this.getClass.getPackage.getName + private val mapper = new ObjectMapper() + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + .enable(SerializationFeature.INDENT_OUTPUT) + .registerModule(DefaultScalaModule) + + /** + * Parse the value of the action field from the given JSON. + * If the action field is not found, throw a [[SubmitRestMissingFieldException]]. + */ + def parseAction(json: String): String = { + val value: Option[String] = parse(json) match { + case JObject(fields) => + fields.collectFirst { case ("action", v) => v }.collect { case JString(s) => s } + case _ => None + } + value.getOrElse { + throw new SubmitRestMissingFieldException(s"Action field not found in JSON:\n$json") + } + } + + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. + * + * This method first parses the action from the JSON and uses it to infer the message type. + * Note that the action must represent one of the [[SubmitRestProtocolMessage]]s defined in + * this package. Otherwise, a [[ClassNotFoundException]] will be thrown. + */ + def fromJson(json: String): SubmitRestProtocolMessage = { + val className = parseAction(json) + val clazz = Class.forName(packagePrefix + "." + className) + .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) + fromJson(json, clazz) + } + + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. + * + * This method determines the type of the message from the class provided instead of + * inferring it from the action field. This is useful for deserializing JSON that + * represents custom user-defined messages. + */ + def fromJson[T <: SubmitRestProtocolMessage](json: String, clazz: Class[T]): T = { + mapper.readValue(json, clazz) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala new file mode 100644 index 0000000000000..9e1fd8c40cabd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import scala.util.Try + +import org.apache.spark.util.Utils + +/** + * An abstract request sent from the client in the REST application submission protocol. + */ +private[spark] abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { + var clientSparkVersion: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(clientSparkVersion, "clientSparkVersion") + } +} + +/** + * A request to launch a new application in the REST application submission protocol. + */ +private[spark] class CreateSubmissionRequest extends SubmitRestProtocolRequest { + var appResource: String = null + var mainClass: String = null + var appArgs: Array[String] = null + var sparkProperties: Map[String, String] = null + var environmentVariables: Map[String, String] = null + + protected override def doValidate(): Unit = { + super.doValidate() + assert(sparkProperties != null, "No Spark properties set!") + assertFieldIsSet(appResource, "appResource") + assertPropertyIsSet("spark.app.name") + assertPropertyIsBoolean("spark.driver.supervise") + assertPropertyIsNumeric("spark.driver.cores") + assertPropertyIsNumeric("spark.cores.max") + assertPropertyIsMemory("spark.driver.memory") + assertPropertyIsMemory("spark.executor.memory") + } + + private def assertPropertyIsSet(key: String): Unit = + assertFieldIsSet(sparkProperties.getOrElse(key, null), key) + + private def assertPropertyIsBoolean(key: String): Unit = + assertProperty[Boolean](key, "boolean", _.toBoolean) + + private def assertPropertyIsNumeric(key: String): Unit = + assertProperty[Int](key, "numeric", _.toInt) + + private def assertPropertyIsMemory(key: String): Unit = + assertProperty[Int](key, "memory", Utils.memoryStringToMb) + + /** Assert that a Spark property can be converted to a certain type. */ + private def assertProperty[T](key: String, valueType: String, convert: (String => T)): Unit = { + sparkProperties.get(key).foreach { value => + Try(convert(value)).getOrElse { + throw new SubmitRestProtocolException( + s"Property '$key' expected $valueType value: actual was '$value'.") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala new file mode 100644 index 0000000000000..16dfe041d4bea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.lang.Boolean + +/** + * An abstract response sent from the server in the REST application submission protocol. + */ +private[spark] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { + var serverSparkVersion: String = null + var success: Boolean = null + var unknownFields: Array[String] = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(serverSparkVersion, "serverSparkVersion") + } +} + +/** + * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol. + */ +private[spark] class CreateSubmissionResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(success, "success") + } +} + +/** + * A response to a kill request in the REST application submission protocol. + */ +private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(submissionId, "submissionId") + assertFieldIsSet(success, "success") + } +} + +/** + * A response to a status request in the REST application submission protocol. + */ +private[spark] class SubmissionStatusResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + var driverState: String = null + var workerId: String = null + var workerHostPort: String = null + + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(submissionId, "submissionId") + assertFieldIsSet(success, "success") + } +} + +/** + * An error response message used in the REST application submission protocol. + */ +private[spark] class ErrorResponse extends SubmitRestProtocolResponse { + // The highest protocol version that the server knows about + // This is set when the client specifies an unknown version + var highestProtocolVersion: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(message, "message") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 28cab36c7b9e2..e16bccb24d2c4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -20,19 +20,18 @@ package org.apache.spark.deploy.worker import java.io._ import scala.collection.JavaConversions._ -import scala.collection.Map import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.{Command, DriverDescription, SparkHadoopUtil} +import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState +import org.apache.spark.util.{Clock, SystemClock} /** * Manages the execution of one driver, including automatically restarting the driver on failure. @@ -59,9 +58,7 @@ private[spark] class DriverRunner( // Decoupled for testing private[deploy] def setClock(_clock: Clock) = clock = _clock private[deploy] def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper - private var clock = new Clock { - def currentTimeMillis(): Long = System.currentTimeMillis() - } + private var clock: Clock = new SystemClock() private var sleeper = new Sleeper { def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed}) } @@ -74,10 +71,15 @@ private[spark] class DriverRunner( val driverDir = createWorkingDirectory() val localJarFilename = downloadUserJar(driverDir) - // Make sure user application jar is on the classpath + def substituteVariables(argument: String): String = argument match { + case "{{WORKER_URL}}" => workerUrl + case "{{USER_JAR}}" => localJarFilename + case other => other + } + // TODO: If we add ability to submit multiple jars they should also be added here val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem, - sparkHome.getAbsolutePath, substituteVariables, Seq(localJarFilename)) + sparkHome.getAbsolutePath, substituteVariables) launchDriver(builder, driverDir, driverDesc.supervise) } catch { @@ -111,12 +113,6 @@ private[spark] class DriverRunner( } } - /** Replace variables in a command argument passed to us */ - private def substituteVariables(argument: String): String = argument match { - case "{{WORKER_URL}}" => workerUrl - case other => other - } - /** * Creates the working directory for this driver. * Will throw an exception if there are errors preparing the directory. @@ -191,9 +187,9 @@ private[spark] class DriverRunner( initialize(process.get) } - val processStart = clock.currentTimeMillis() + val processStart = clock.getTimeMillis() val exitCode = process.get.waitFor() - if (clock.currentTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { waitSeconds = 1 } @@ -209,10 +205,6 @@ private[spark] class DriverRunner( } } -private[deploy] trait Clock { - def currentTimeMillis(): Long -} - private[deploy] trait Sleeper { def sleep(seconds: Int) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 05e242e6df702..deef6ef9043c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -17,32 +17,51 @@ package org.apache.spark.deploy.worker +import java.io.File + import akka.actor._ import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, Utils} /** * Utility object for launching driver programs such that they share fate with the Worker process. + * This is used in standalone cluster mode only. */ object DriverWrapper { def main(args: Array[String]) { args.toList match { - case workerUrl :: mainClass :: extraArgs => + /* + * IMPORTANT: Spark 1.3 provides a stable application submission gateway that is both + * backward and forward compatible across future Spark versions. Because this gateway + * uses this class to launch the driver, the ordering and semantics of the arguments + * here must also remain consistent across versions. + */ + case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + val currentLoader = Thread.currentThread.getContextClassLoader + val userJarUrl = new File(userJar).toURI().toURL() + val loader = + if (sys.props.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + new ChildFirstURLClassLoader(Array(userJarUrl), currentLoader) + } else { + new MutableURLClassLoader(Array(userJarUrl), currentLoader) + } + Thread.currentThread.setContextClassLoader(loader) + // Delegate to supplied main class - val clazz = Class.forName(args(1)) + val clazz = Class.forName(mainClass, true, loader) val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) actorSystem.shutdown() case _ => - System.err.println("Usage: DriverWrapper [options]") + System.err.println("Usage: DriverWrapper [options]") System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index bc9f78b9e5c77..a3ec803461400 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -43,6 +43,8 @@ private[spark] class ExecutorRunner( val worker: ActorRef, val workerId: String, val host: String, + val webUiPort: Int, + val publicAddress: String, val sparkHome: File, val executorDir: File, val workerUrl: String, @@ -84,14 +86,13 @@ private[spark] class ExecutorRunner( var exitCode: Option[Int] = None if (process != null) { logInfo("Killing process!") - process.destroy() - process.waitFor() if (stdoutAppender != null) { stdoutAppender.stop() } if (stderrAppender != null) { stderrAppender.stop() } + process.destroy() exitCode = Some(process.waitFor()) } worker ! ExecutorStateChanged(appId, execId, state, message, exitCode) @@ -104,7 +105,11 @@ private[spark] class ExecutorRunner( workerThread.interrupt() workerThread = null state = ExecutorState.KILLED - Runtime.getRuntime.removeShutdownHook(shutdownHook) + try { + Runtime.getRuntime.removeShutdownHook(shutdownHook) + } catch { + case e: IllegalStateException => None + } } } @@ -134,6 +139,13 @@ private[spark] class ExecutorRunner( // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0") + + // Add webUI log urls + val baseUrl = + s"http://$publicAddress:$webUiPort/logPage/?appId=$appId&executorId=$execId&logType=" + builder.environment.put("SPARK_LOG_URL_STDERR", s"${baseUrl}stderr") + builder.environment.put("SPARK_LOG_URL_STDOUT", s"${baseUrl}stdout") + process = builder.start() val header = "Spark Executor Command: %s\n%s\n\n".format( command.mkString("\"", "\" \"", "\""), "=" * 40) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index b20f5c0c82895..3036cd6e3c8a2 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -121,7 +121,7 @@ private[spark] class Worker( val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) val publicAddress = { - val envVar = System.getenv("SPARK_PUBLIC_DNS") + val envVar = conf.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host } var webUi: WorkerWebUI = null @@ -362,6 +362,8 @@ private[spark] class Worker( self, workerId, host, + webUi.boundPort, + publicAddress, sparkHome, executorDir, akkaUrl, @@ -537,10 +539,10 @@ private[spark] object Worker extends Logging { memory: Int, masterUrls: Array[String], workDir: String, - workerNumber: Option[Int] = None): (ActorSystem, Int) = { + workerNumber: Option[Int] = None, + conf: SparkConf = new SparkConf): (ActorSystem, Int) = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems - val conf = new SparkConf val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" val securityMgr = new SecurityManager(conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 327b905032800..720f13bfa829b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -134,7 +134,7 @@ private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { def driverRow(driver: DriverRunner): Seq[Node] = { {driver.driverId} - {driver.driverDesc.command.arguments(1)} + {driver.driverDesc.command.arguments(2)} {driver.finalState.getOrElse(DriverState.RUNNING)} {driver.driverDesc.cores.toString} diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index bc72c8970319c..dd19e4947db1e 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -17,8 +17,10 @@ package org.apache.spark.executor +import java.net.URL import java.nio.ByteBuffer +import scala.collection.mutable import scala.concurrent.Await import akka.actor.{Actor, ActorSelection, Props} @@ -38,6 +40,7 @@ private[spark] class CoarseGrainedExecutorBackend( executorId: String, hostPort: String, cores: Int, + userClassPath: Seq[URL], env: SparkEnv) extends Actor with ActorLogReceive with ExecutorBackend with Logging { @@ -49,15 +52,21 @@ private[spark] class CoarseGrainedExecutorBackend( override def preStart() { logInfo("Connecting to driver: " + driverUrl) driver = context.actorSelection(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores) + driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls) context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } + def extractLogUrls: Map[String, String] = { + val prefix = "SPARK_LOG_URL_" + sys.env.filterKeys(_.startsWith(prefix)) + .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) + } + override def receiveWithLogging = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) - executor = new Executor(executorId, hostname, env, isLocal = false) + executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) @@ -111,7 +120,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { hostname: String, cores: Int, appId: String, - workerUrl: Option[String]) { + workerUrl: Option[String], + userClassPath: Seq[URL]) { SignalLogger.register(log) @@ -156,7 +166,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val sparkHostPort = hostname + ":" + boundPort env.actorSystem.actorOf( Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, env), + driverUrl, executorId, sparkHostPort, cores, userClassPath, env), name = "Executor") workerUrl.foreach { url => env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") @@ -166,20 +176,69 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } def main(args: Array[String]) { - args.length match { - case x if x < 5 => - System.err.println( + var driverUrl: String = null + var executorId: String = null + var hostname: String = null + var cores: Int = 0 + var appId: String = null + var workerUrl: Option[String] = None + val userClassPath = new mutable.ListBuffer[URL]() + + var argv = args.toList + while (!argv.isEmpty) { + argv match { + case ("--driver-url") :: value :: tail => + driverUrl = value + argv = tail + case ("--executor-id") :: value :: tail => + executorId = value + argv = tail + case ("--hostname") :: value :: tail => + hostname = value + argv = tail + case ("--cores") :: value :: tail => + cores = value.toInt + argv = tail + case ("--app-id") :: value :: tail => + appId = value + argv = tail + case ("--worker-url") :: value :: tail => // Worker url is used in spark standalone mode to enforce fate-sharing with worker - "Usage: CoarseGrainedExecutorBackend " + - " [] ") - System.exit(1) + workerUrl = Some(value) + argv = tail + case ("--user-class-path") :: value :: tail => + userClassPath += new URL(value) + argv = tail + case Nil => + case tail => + System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") + printUsageAndExit() + } + } - // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode) - // and CoarseMesosSchedulerBackend (for mesos mode). - case 5 => - run(args(0), args(1), args(2), args(3).toInt, args(4), None) - case x if x > 5 => - run(args(0), args(1), args(2), args(3).toInt, args(4), Some(args(5))) + if (driverUrl == null || executorId == null || hostname == null || cores <= 0 || + appId == null) { + printUsageAndExit() } + + run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath) + } + + private def printUsageAndExit() = { + System.err.println( + """ + |"Usage: CoarseGrainedExecutorBackend [options] + | + | Options are: + | --driver-url + | --executor-id + | --hostname + | --cores + | --app-id + | --worker-url + | --user-class-path + |""".stripMargin) + System.exit(1) } + } diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala new file mode 100644 index 0000000000000..f7604a321f007 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.{TaskCommitDenied, TaskEndReason} + +/** + * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver. + */ +class CommitDeniedException( + msg: String, + jobID: Int, + splitID: Int, + attemptID: Int) + extends Exception(msg) { + + def toTaskEndReason: TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID) + +} + diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 312bb3a1daaa3..c6ff38d527d8d 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -19,6 +19,7 @@ package org.apache.spark.executor import java.io.File import java.lang.management.ManagementFactory +import java.net.URL import java.nio.ByteBuffer import java.util.concurrent._ @@ -33,7 +34,8 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} -import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils} +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, + SparkUncaughtExceptionHandler, AkkaUtils, Utils} /** * Spark executor used with Mesos, YARN, and the standalone scheduler. @@ -43,6 +45,7 @@ private[spark] class Executor( executorId: String, executorHostname: String, env: SparkEnv, + userClassPath: Seq[URL] = Nil, isLocal: Boolean = false) extends Logging { @@ -75,6 +78,9 @@ private[spark] class Executor( Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) } + // Start worker thread pool + val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") + val executorSource = new ExecutorSource(this, executorId) if (!isLocal) { @@ -86,13 +92,19 @@ private[spark] class Executor( private val executorActor = env.actorSystem.actorOf( Props(new ExecutorActor(executorId)), "ExecutorActor") + // Whether to load classes in user jars before those in Spark jars + private val userClassPathFirst: Boolean = { + conf.getBoolean("spark.executor.userClassPathFirst", + conf.getBoolean("spark.files.userClassPathFirst", false)) + } + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) // Set the classloader for serializer - env.serializer.setDefaultClassLoader(urlClassLoader) + env.serializer.setDefaultClassLoader(replClassLoader) // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. @@ -101,9 +113,6 @@ private[spark] class Executor( // Limit of bytes for total size of results (default is 1GB) private val maxResultSize = Utils.getMaxResultSize(conf) - // Start worker thread pool - val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") - // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] @@ -250,6 +259,11 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) } + case cDE: CommitDeniedException => { + val reason = cDE.toTaskEndReason + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + } + case t: Throwable => { // Attempt to exit cleanly by informing the driver of our failure. // If anything goes wrong (or this was a fatal exception), we will delegate to @@ -288,17 +302,23 @@ private[spark] class Executor( * created by the interpreter to the search path */ private def createClassLoader(): MutableURLClassLoader = { + // Bootstrap the list of jars with the user class path. + val now = System.currentTimeMillis() + userClassPath.foreach { url => + currentJars(url.getPath().split("/").last) = now + } + val currentLoader = Utils.getContextOrSparkClassLoader // For each of the jars in the jarSet, add them to the class loader. // We assume each of the files has already been fetched. - val urls = currentJars.keySet.map { uri => + val urls = userClassPath.toArray ++ currentJars.keySet.map { uri => new File(uri.split("/").last).toURI.toURL - }.toArray - val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false) - userClassPathFirst match { - case true => new ChildExecutorURLClassLoader(urls, currentLoader) - case false => new ExecutorURLClassLoader(urls, currentLoader) + } + if (userClassPathFirst) { + new ChildFirstURLClassLoader(urls, currentLoader) + } else { + new MutableURLClassLoader(urls, currentLoader) } } @@ -310,14 +330,13 @@ private[spark] class Executor( val classUri = conf.get("spark.repl.class.uri", null) if (classUri != null) { logInfo("Using REPL class URI: " + classUri) - val userClassPathFirst: java.lang.Boolean = - conf.getBoolean("spark.files.userClassPathFirst", false) try { + val _userClassPathFirst: java.lang.Boolean = userClassPathFirst val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], classOf[ClassLoader], classOf[Boolean]) - constructor.newInstance(conf, classUri, parent, userClassPathFirst) + constructor.newInstance(conf, classUri, parent, _userClassPathFirst) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") @@ -344,18 +363,23 @@ private[spark] class Executor( env.securityManager, hadoopConf, timestamp, useCache = !isLocal) currentFiles(name) = timestamp } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - // Fetch file with useCache mode, close cache for local mode. - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, - env.securityManager, hadoopConf, timestamp, useCache = !isLocal) - currentJars(name) = timestamp - // Add it to our class loader + for ((name, timestamp) <- newJars) { val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) + val currentTimeStamp = currentJars.get(name) + .orElse(currentJars.get(localName)) + .getOrElse(-1L) + if (currentTimeStamp < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + // Fetch file with useCache mode, close cache for local mode. + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, + env.securityManager, hadoopConf, timestamp, useCache = !isLocal) + currentJars(name) = timestamp + // Add it to our class loader + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL + if (!urlClassLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + } } } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala deleted file mode 100644 index 218ed7b5d2d39..0000000000000 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import java.net.{URLClassLoader, URL} - -import org.apache.spark.util.ParentClassLoader - -/** - * The addURL method in URLClassLoader is protected. We subclass it to make this accessible. - * We also make changes so user classes can come before the default classes. - */ - -private[spark] trait MutableURLClassLoader extends ClassLoader { - def addURL(url: URL) - def getURLs: Array[URL] -} - -private[spark] class ChildExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) - extends MutableURLClassLoader { - - private object userClassLoader extends URLClassLoader(urls, null){ - override def addURL(url: URL) { - super.addURL(url) - } - override def findClass(name: String): Class[_] = { - super.findClass(name) - } - } - - private val parentClassLoader = new ParentClassLoader(parent) - - override def findClass(name: String): Class[_] = { - try { - userClassLoader.findClass(name) - } catch { - case e: ClassNotFoundException => { - parentClassLoader.loadClass(name) - } - } - } - - def addURL(url: URL) { - userClassLoader.addURL(url) - } - - def getURLs() = { - userClassLoader.getURLs() - } -} - -private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) - extends URLClassLoader(urls, parent) with MutableURLClassLoader { - - override def addURL(url: URL) { - super.addURL(url) - } -} - diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 97912c68c5982..07b152651dedf 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -177,8 +177,8 @@ class TaskMetrics extends Serializable { * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed, * we can store all the different inputMetrics (one per readMethod). */ - private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): - InputMetrics =synchronized { + private[spark] def getInputMetricsForReadMethod( + readMethod: DataReadMethod): InputMetrics = synchronized { _inputMetrics match { case None => val metrics = new InputMetrics(readMethod) @@ -194,18 +194,22 @@ class TaskMetrics extends Serializable { /** * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. */ - private[spark] def updateShuffleReadMetrics() = synchronized { - val merged = new ShuffleReadMetrics() - for (depMetrics <- depsShuffleReadMetrics) { - merged.incFetchWaitTime(depMetrics.fetchWaitTime) - merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) - merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) - merged.incRemoteBytesRead(depMetrics.remoteBytesRead) + private[spark] def updateShuffleReadMetrics(): Unit = synchronized { + if (!depsShuffleReadMetrics.isEmpty) { + val merged = new ShuffleReadMetrics() + for (depMetrics <- depsShuffleReadMetrics) { + merged.incFetchWaitTime(depMetrics.fetchWaitTime) + merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) + merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) + merged.incRemoteBytesRead(depMetrics.remoteBytesRead) + merged.incLocalBytesRead(depMetrics.localBytesRead) + merged.incRecordsRead(depMetrics.recordsRead) + } + _shuffleReadMetrics = Some(merged) } - _shuffleReadMetrics = Some(merged) } - private[spark] def updateInputMetrics() = synchronized { + private[spark] def updateInputMetrics(): Unit = synchronized { inputMetrics.foreach(_.updateBytesRead()) } } @@ -242,27 +246,31 @@ object DataWriteMethod extends Enumeration with Serializable { @DeveloperApi case class InputMetrics(readMethod: DataReadMethod.Value) { - private val _bytesRead: AtomicLong = new AtomicLong() + /** + * This is volatile so that it is visible to the updater thread. + */ + @volatile @transient var bytesReadCallback: Option[() => Long] = None /** * Total bytes read. */ - def bytesRead: Long = _bytesRead.get() - @volatile @transient var bytesReadCallback: Option[() => Long] = None + private var _bytesRead: Long = _ + def bytesRead: Long = _bytesRead + def incBytesRead(bytes: Long) = _bytesRead += bytes /** - * Adds additional bytes read for this read method. + * Total records read. */ - def addBytesRead(bytes: Long) = { - _bytesRead.addAndGet(bytes) - } + private var _recordsRead: Long = _ + def recordsRead: Long = _recordsRead + def incRecordsRead(records: Long) = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. */ def updateBytesRead() { bytesReadCallback.foreach { c => - _bytesRead.set(c()) + _bytesRead = c() } } @@ -287,6 +295,13 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) { private var _bytesWritten: Long = _ def bytesWritten = _bytesWritten private[spark] def setBytesWritten(value : Long) = _bytesWritten = value + + /** + * Total records written + */ + private var _recordsWritten: Long = 0L + def recordsWritten = _recordsWritten + private[spark] def setRecordsWritten(value: Long) = _recordsWritten = value } /** @@ -301,7 +316,7 @@ class ShuffleReadMetrics extends Serializable { private var _remoteBlocksFetched: Int = _ def remoteBlocksFetched = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value - private[spark] def defRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value + private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value /** * Number of local blocks fetched in this shuffle by this task @@ -309,8 +324,7 @@ class ShuffleReadMetrics extends Serializable { private var _localBlocksFetched: Int = _ def localBlocksFetched = _localBlocksFetched private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value - private[spark] def defLocalBlocksFetched(value: Int) = _localBlocksFetched -= value - + private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value /** * Time the task spent waiting for remote shuffle blocks. This only includes the time @@ -330,10 +344,30 @@ class ShuffleReadMetrics extends Serializable { private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value + /** + * Shuffle data that was read from the local disk (as opposed to from a remote executor). + */ + private var _localBytesRead: Long = _ + def localBytesRead = _localBytesRead + private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value + + /** + * Total bytes fetched in the shuffle by this task (both remote and local). + */ + def totalBytesRead = _remoteBytesRead + _localBytesRead + /** * Number of blocks fetched in this shuffle by this task (remote or local) */ def totalBlocksFetched = _remoteBlocksFetched + _localBlocksFetched + + /** + * Total number of records read from the shuffle by this task + */ + private var _recordsRead: Long = _ + def recordsRead = _recordsRead + private[spark] def incRecordsRead(value: Long) = _recordsRead += value + private[spark] def decRecordsRead(value: Long) = _recordsRead -= value } /** @@ -358,5 +392,12 @@ class ShuffleWriteMetrics extends Serializable { private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value - + /** + * Total number of records written to the shuffle by this task + */ + @volatile private var _shuffleRecordsWritten: Long = _ + def shuffleRecordsWritten = _shuffleRecordsWritten + private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value + private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value + private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value } diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index f856890d279f4..0709b6d689e86 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -26,7 +26,6 @@ import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream} import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils -import org.apache.spark.Logging /** * :: DeveloperApi :: @@ -53,8 +52,12 @@ private[spark] object CompressionCodec { "lzf" -> classOf[LZFCompressionCodec].getName, "snappy" -> classOf[SnappyCompressionCodec].getName) + def getCodecName(conf: SparkConf): String = { + conf.get(configKey, DEFAULT_COMPRESSION_CODEC) + } + def createCodec(conf: SparkConf): CompressionCodec = { - createCodec(conf, conf.get(configKey, DEFAULT_COMPRESSION_CODEC)) + createCodec(conf, getCodecName(conf)) } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { @@ -71,6 +74,20 @@ private[spark] object CompressionCodec { s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC")) } + /** + * Return the short version of the given codec name. + * If it is already a short name, just return it. + */ + def getShortName(codecName: String): String = { + if (shortCompressionCodecNames.contains(codecName)) { + codecName + } else { + shortCompressionCodecNames + .collectFirst { case (k, v) if v == codecName => k } + .getOrElse { throw new IllegalArgumentException(s"No short name for codec $codecName.") } + } + } + val FALLBACK_COMPRESSION_CODEC = "lzf" val DEFAULT_COMPRESSION_CODEC = "snappy" val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 21b782edd2a9e..639ee78d6c599 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -17,9 +17,15 @@ package org.apache.spark.mapred +import java.io.IOException import java.lang.reflect.Modifier -import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext} +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} +import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} + +import org.apache.spark.executor.CommitDeniedException +import org.apache.spark.{Logging, SparkEnv, TaskContext} private[spark] trait SparkHadoopMapRedUtil { @@ -65,3 +71,86 @@ trait SparkHadoopMapRedUtil { } } } + +object SparkHadoopMapRedUtil extends Logging { + /** + * Commits a task output. Before committing the task output, we need to know whether some other + * task attempt might be racing to commit the same output partition. Therefore, coordinate with + * the driver in order to determine whether this attempt can commit (please see SPARK-4879 for + * details). + * + * Output commit coordinator is only contacted when the following two configurations are both set + * to `true`: + * + * - `spark.speculation` + * - `spark.hadoop.outputCommitCoordination.enabled` + */ + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + jobId: Int, + splitId: Int, + attemptId: Int): Unit = { + + val mrTaskAttemptID = mrTaskContext.getTaskAttemptID + + // Called after we have decided to commit + def performCommit(): Unit = { + try { + committer.commitTask(mrTaskContext) + logInfo(s"$mrTaskAttemptID: Committed") + } catch { + case cause: IOException => + logError(s"Error committing the output of task: $mrTaskAttemptID", cause) + committer.abortTask(mrTaskContext) + throw cause + } + } + + // First, check whether the task's output has already been committed by some other attempt + if (committer.needsTaskCommit(mrTaskContext)) { + val shouldCoordinateWithDriver: Boolean = { + val sparkConf = SparkEnv.get.conf + // We only need to coordinate with the driver if there are multiple concurrent task + // attempts, which should only occur if speculation is enabled + val speculationEnabled = sparkConf.getBoolean("spark.speculation", defaultValue = false) + // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs + sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) + } + + if (shouldCoordinateWithDriver) { + val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + + if (canCommit) { + performCommit() + } else { + val message = + s"$mrTaskAttemptID: Not committed because the driver did not authorize commit" + logInfo(message) + // We need to abort the task so that the driver can reschedule new attempts, if necessary + committer.abortTask(mrTaskContext) + throw new CommitDeniedException(message, jobId, splitId, attemptId) + } + } else { + // Speculation is disabled or a user has chosen to manually bypass the commit coordination + performCommit() + } + } else { + // Some other attempt committed the output, so we do nothing and signal success + logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") + } + } + + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + sparkTaskContext: TaskContext): Unit = { + commitTask( + committer, + mrTaskContext, + sparkTaskContext.stageId(), + sparkTaskContext.partitionId(), + sparkTaskContext.attemptNumber()) + } +} diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index 1b7a5d1f1980a..8edf493780687 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -28,12 +28,12 @@ import org.apache.spark.util.Utils private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { - val DEFAULT_PREFIX = "*" - val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r - val METRICS_CONF = "metrics.properties" + private val DEFAULT_PREFIX = "*" + private val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r + private val DEFAULT_METRICS_CONF_FILENAME = "metrics.properties" - val properties = new Properties() - var propertyCategories: mutable.HashMap[String, Properties] = null + private[metrics] val properties = new Properties() + private[metrics] var propertyCategories: mutable.HashMap[String, Properties] = null private def setDefaultProperties(prop: Properties) { prop.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet") @@ -47,20 +47,22 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi setDefaultProperties(properties) // If spark.metrics.conf is not set, try to get file in class path - var is: InputStream = null - try { - is = configFile match { - case Some(f) => new FileInputStream(f) - case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF) + val isOpt: Option[InputStream] = configFile.map(new FileInputStream(_)).orElse { + try { + Option(Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME)) + } catch { + case e: Exception => + logError("Error loading default configuration file", e) + None } + } - if (is != null) { + isOpt.foreach { is => + try { properties.load(is) + } finally { + is.close() } - } catch { - case e: Exception => logError("Error loading configure file", e) - } finally { - if (is != null) is.close() } propertyCategories = subProperties(properties, INSTANCE_REGEX) diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 83e8eb71260eb..345db36630fd5 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -191,7 +191,10 @@ private[spark] class MetricsSystem private ( sinks += sink.asInstanceOf[Sink] } } catch { - case e: Exception => logError("Sink class " + classPath + " cannot be instantialized", e) + case e: Exception => { + logError("Sink class " + classPath + " cannot be instantialized") + throw e + } } } } diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 5ad73c3d27f47..7780372cbabed 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -27,8 +27,7 @@ package org.apache * contains operations available only on RDDs of Doubles; and * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that can * be saved as SequenceFiles. These operations are automatically available on any RDD of the right - * type (e.g. RDD[(Int, Int)] through implicit conversions except `saveAsSequenceFile`. You need to - * `import org.apache.spark.SparkContext._` to make `saveAsSequenceFile` work. + * type (e.g. RDD[(Int, Int)] through implicit conversions. * * Java programmers should reference the [[org.apache.spark.api.java]] package * for Spark programming APIs in Java. @@ -44,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.3.0-SNAPSHOT" + val SPARK_VERSION = "1.3.1" } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index e66f83bb34e30..a3a03ef0495c7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -191,25 +191,23 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } } // Determine the bucket function in constant time. Requires that buckets are evenly spaced - def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = { // If our input is not a number unless the increment is also NaN then we fail fast - if (e.isNaN()) { - return None - } - val bucketNumber = (e - min)/(increment) - // We do this rather than buckets.lengthCompare(bucketNumber) - // because Array[Double] fails to override it (for now). - if (bucketNumber > count || bucketNumber < 0) { + if (e.isNaN || e < min || e > max) { None } else { - Some(bucketNumber.toInt.min(count - 1)) + // Compute ratio of e's distance along range to total range first, for better precision + val bucketNumber = (((e - min) / (max - min)) * count).toInt + // should be less than count, but will equal count if e == max, in which case + // it's part of the last end-range-inclusive bucket, so return count-1 + Some(math.min(bucketNumber, count - 1)) } } // Decide which bucket function to pass to histogramPartition. We decide here - // rather than having a general function so that the decission need only be made + // rather than having a general function so that the decision need only be made // once rather than once per shard val bucketFunction = if (evenBuckets) { - fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + fastBucketFunction(buckets.head, buckets.last, buckets.length - 1) _ } else { basicBucketFunction _ } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 89adddcf0ac36..486e86ce1bb19 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -247,7 +247,9 @@ class HadoopRDD[K, V]( case eof: EOFException => finished = true } - + if (!finished) { + inputMetrics.incRecordsRead(1) + } (key, value) } @@ -261,7 +263,7 @@ class HadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.addBytesRead(split.inputSplit.value.getLength) + inputMetrics.incBytesRead(split.inputSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 642a12c1edf6c..4fe7622bda00f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -99,21 +99,21 @@ class JdbcRDD[T: ClassTag]( override def close() { try { - if (null != rs && ! rs.isClosed()) { + if (null != rs) { rs.close() } } catch { case e: Exception => logWarning("Exception closing resultset", e) } try { - if (null != stmt && ! stmt.isClosed()) { + if (null != stmt) { stmt.close() } } catch { case e: Exception => logWarning("Exception closing statement", e) } try { - if (null != conn && ! conn.isClosed()) { + if (null != conn) { conn.close() } logInfo("closed connection") diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 44b9ffd2a53fd..7fb94840df99c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -151,7 +151,9 @@ class NewHadoopRDD[K, V]( throw new java.util.NoSuchElementException("End of stream") } havePair = false - + if (!finished) { + inputMetrics.incRecordsRead(1) + } (reader.getCurrentKey, reader.getCurrentValue) } @@ -165,7 +167,7 @@ class NewHadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.addBytesRead(split.serializableHadoopSplit.value.getLength) + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 49b88a90ab5af..955b42c3baaa1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, -RecordWriter => NewRecordWriter} + RecordWriter => NewRecordWriter} import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner @@ -993,8 +993,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] + var recordsWritten = 0L try { - var recordsWritten = 0L while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) @@ -1008,6 +1008,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } committer.commitTask(hadoopContext) bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) 1 } : Int @@ -1065,8 +1066,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() + var recordsWritten = 0L try { - var recordsWritten = 0L while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) @@ -1080,6 +1081,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } writer.commit() bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) } self.context.runJob(self, writeToFile) @@ -1097,9 +1099,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long], outputMetrics: OutputMetrics, recordsWritten: Long): Unit = { - if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0 - && bytesWrittenCallback.isDefined) { + if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 97aee58bddbf1..3ab9e54f0ec56 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -25,11 +25,8 @@ import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus -import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.{Writable, BytesWritable, NullWritable, Text} import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ @@ -57,8 +54,7 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that * can be saved as SequenceFiles. * All operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] - * through implicit conversions except `saveAsSequenceFile`. You need to - * `import org.apache.spark.SparkContext._` to make `saveAsSequenceFile` work. + * through implicit. * * Internally, each RDD is characterized by five main properties: * @@ -466,7 +462,13 @@ abstract class RDD[T: ClassTag]( * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). */ - def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other)) + def union(other: RDD[T]): RDD[T] = { + if (partitioner.isDefined && other.partitioner == partitioner) { + new PartitionerAwareUnionRDD(sc, Array(this, other)) + } else { + new UnionRDD(sc, Array(this, other)) + } + } /** * Return the union of this RDD and another one. Any identical elements will appear multiple @@ -1527,7 +1529,7 @@ abstract class RDD[T: ClassTag]( */ object RDD { - // The following implicit functions were in SparkContext before 1.2 and users had to + // The following implicit functions were in SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward // compatibility and forward to the following functions directly. @@ -1541,9 +1543,15 @@ object RDD { new AsyncRDDActions(rdd) } - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = { - new SequenceFileRDDFunctions(rdd) + implicit def rddToSequenceFileRDDFunctions[K, V](rdd: RDD[(K, V)]) + (implicit kt: ClassTag[K], vt: ClassTag[V], + keyWritableFactory: WritableFactory[K], + valueWritableFactory: WritableFactory[V]) + : SequenceFileRDDFunctions[K, V] = { + implicit val keyConverter = keyWritableFactory.convert + implicit val valueConverter = valueWritableFactory.convert + new SequenceFileRDDFunctions(rdd, + keyWritableFactory.writableClass(kt), valueWritableFactory.writableClass(vt)) } implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](rdd: RDD[(K, V)]) diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 2b48916951430..059f8963691f0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -30,13 +30,35 @@ import org.apache.spark.Logging * through an implicit conversion. Note that this can't be part of PairRDDFunctions because * we need more implicit parameters to convert our keys and values to Writable. * - * Import `org.apache.spark.SparkContext._` at the top of their program to use these functions. */ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( - self: RDD[(K, V)]) + self: RDD[(K, V)], + _keyWritableClass: Class[_ <: Writable], + _valueWritableClass: Class[_ <: Writable]) extends Logging with Serializable { + @deprecated("It's used to provide backward compatibility for pre 1.3.0.", "1.3.0") + def this(self: RDD[(K, V)]) { + this(self, null, null) + } + + private val keyWritableClass = + if (_keyWritableClass == null) { + // pre 1.3.0, we need to use Reflection to get the Writable class + getWritableClass[K]() + } else { + _keyWritableClass + } + + private val valueWritableClass = + if (_valueWritableClass == null) { + // pre 1.3.0, we need to use Reflection to get the Writable class + getWritableClass[V]() + } else { + _valueWritableClass + } + private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = { val c = { if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) { @@ -55,6 +77,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag c.asInstanceOf[Class[_ <: Writable]] } + /** * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key * and value types. If the key or value are Writable, then we use their classes directly; @@ -65,26 +88,28 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { def anyToWritable[U <% Writable](u: U): Writable = u - val keyClass = getWritableClass[K] - val valueClass = getWritableClass[V] - val convertKey = !classOf[Writable].isAssignableFrom(self.keyClass) - val convertValue = !classOf[Writable].isAssignableFrom(self.valueClass) + // TODO We cannot force the return type of `anyToWritable` be same as keyWritableClass and + // valueWritableClass at the compile time. To implement that, we need to add type parameters to + // SequenceFileRDDFunctions. however, SequenceFileRDDFunctions is a public class so it will be a + // breaking change. + val convertKey = self.keyClass != keyWritableClass + val convertValue = self.valueClass != valueWritableClass - logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + - valueClass.getSimpleName + ")" ) + logInfo("Saving as sequence file of type (" + keyWritableClass.getSimpleName + "," + + valueWritableClass.getSimpleName + ")" ) val format = classOf[SequenceFileOutputFormat[Writable, Writable]] val jobConf = new JobConf(self.context.hadoopConfiguration) if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) + self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (!convertKey && convertValue) { self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && !convertValue) { self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && convertValue) { self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 1cfe98673773a..4a79ebabe6360 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ -import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils} +import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat /** @@ -54,6 +54,10 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * Here's a checklist to use when making or reviewing changes to this class: + * + * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to + * include the new structure. This will help to catch memory leaks. */ private[spark] class DAGScheduler( @@ -63,7 +67,7 @@ class DAGScheduler( mapOutputTracker: MapOutputTrackerMaster, blockManagerMaster: BlockManagerMaster, env: SparkEnv, - clock: Clock = SystemClock) + clock: Clock = new SystemClock()) extends Logging { def this(sc: SparkContext, taskScheduler: TaskScheduler) = { @@ -98,7 +102,13 @@ class DAGScheduler( private[scheduler] val activeJobs = new HashSet[ActiveJob] - // Contains the locations that each RDD's partitions are cached on + /** + * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids + * and its values are arrays indexed by partition numbers. Each array value is the set of + * locations where that RDD partition is cached. + * + * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). + */ private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]] // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with @@ -109,6 +119,8 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] + private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator + // A closure serializer that we reuse. // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() @@ -181,7 +193,8 @@ class DAGScheduler( eventProcessLoop.post(TaskSetFailed(taskSet, reason)) } - private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { + private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = cacheLocs.synchronized { + // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster) @@ -192,7 +205,7 @@ class DAGScheduler( cacheLocs(rdd.id) } - private def clearCacheLocs() { + private def clearCacheLocs(): Unit = cacheLocs.synchronized { cacheLocs.clear() } @@ -465,8 +478,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): JobWaiter[U] = - { + properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. val maxPartitions = rdd.partitions.length partitions.find(p => p >= maxPartitions || p < 0).foreach { p => @@ -495,8 +507,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null) - { + properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { @@ -517,9 +528,7 @@ class DAGScheduler( evaluator: ApproximateEvaluator[U, R], callSite: CallSite, timeout: Long, - properties: Properties = null) - : PartialResult[R] = - { + properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray @@ -648,7 +657,7 @@ class DAGScheduler( // completion events or stage abort stageIdToStage -= s.id jobIdToStageIds -= job.jobId - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), jobResult)) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), jobResult)) } } @@ -666,7 +675,7 @@ class DAGScheduler( // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. val activeInGroup = activeJobs.filter(activeJob => - groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + Option(activeJob.properties).exists(_.get(SparkContext.SPARK_JOB_GROUP_ID) == groupId)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId))) submitWaitingStages() @@ -693,11 +702,12 @@ class DAGScheduler( // cancelling the stages because if the DAG scheduler is stopped, the entire application // is in the process of getting stopped. val stageFailedMessage = "Stage cancelled because SparkContext was shut down" - runningStages.foreach { stage => - stage.latestInfo.stageFailed(stageFailedMessage) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // The `toArray` here is necessary so that we don't iterate over `runningStages` while + // mutating it. + runningStages.toArray.foreach { stage => + markStageAsFinished(stage, Some(stageFailedMessage)) } - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error))) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } } @@ -713,8 +723,7 @@ class DAGScheduler( allowLocal: Boolean, callSite: CallSite, listener: JobListener, - properties: Properties = null) - { + properties: Properties) { var finalStage: Stage = null try { // New stage creation may throw an exception if, for example, jobs are run on a @@ -736,7 +745,7 @@ class DAGScheduler( logInfo("Missing parents: " + getMissingParentStages(finalStage)) val shouldRunLocally = localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 - val jobSubmissionTime = clock.getTime() + val jobSubmissionTime = clock.getTimeMillis() if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. listenerBus.post( @@ -808,6 +817,7 @@ class DAGScheduler( // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size)) + outputCommitCoordinator.stageStart(stage.id) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. @@ -861,14 +871,13 @@ class DAGScheduler( logDebug("New pending tasks: " + stage.pendingTasks) taskScheduler.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) - stage.latestInfo.submissionTime = Some(clock.getTime()) + stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { - // Because we posted SparkListenerStageSubmitted earlier, we should post - // SparkListenerStageCompleted here in case there are no tasks to run. - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // Because we posted SparkListenerStageSubmitted earlier, we should mark + // the stage as completed here in case there are no tasks to run + markStageAsFinished(stage, None) logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) - runningStages -= stage } } @@ -909,6 +918,9 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) + outputCommitCoordinator.taskCompleted(stageId, task.partitionId, + event.taskInfo.attempt, event.reason) + // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. if (event.reason != Success) { @@ -921,23 +933,8 @@ class DAGScheduler( // Skip all the actions if the stage has been cancelled. return } - val stage = stageIdToStage(task.stageId) - def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = { - val serviceTime = stage.latestInfo.submissionTime match { - case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0) - case _ => "Unknown" - } - if (errorMessage.isEmpty) { - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.latestInfo.completionTime = Some(clock.getTime()) - } else { - stage.latestInfo.stageFailed(errorMessage.get) - logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) - } - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - runningStages -= stage - } + val stage = stageIdToStage(task.stageId) event.reason match { case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, @@ -956,7 +953,7 @@ class DAGScheduler( markStageAsFinished(stage) cleanupStateForJobAndIndependentStages(job) listenerBus.post( - SparkListenerJobEnd(job.jobId, clock.getTime(), JobSucceeded)) + SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) } // taskSucceeded runs some user code that might throw an exception. Make sure @@ -1045,7 +1042,6 @@ class DAGScheduler( logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") markStageAsFinished(failedStage, Some(failureMessage)) - runningStages -= failedStage } if (disallowStageRetryForTest) { @@ -1073,6 +1069,9 @@ class DAGScheduler( handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } + case commitDenied: TaskCommitDenied => + // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits + case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures @@ -1158,6 +1157,26 @@ class DAGScheduler( submitWaitingStages() } + /** + * Marks a stage as finished and removes it from the list of running stages. + */ + private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + val serviceTime = stage.latestInfo.submissionTime match { + case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) + case _ => "Unknown" + } + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) + } + outputCommitCoordinator.stageEnd(stage.id) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + runningStages -= stage + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. @@ -1169,7 +1188,7 @@ class DAGScheduler( } val dependentJobs: Seq[ActiveJob] = activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq - failedStage.latestInfo.completionTime = Some(clock.getTime()) + failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) for (job <- dependentJobs) { failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") } @@ -1209,8 +1228,7 @@ class DAGScheduler( if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask taskScheduler.cancelTasks(stageId, shouldInterruptThread) - stage.latestInfo.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + markStageAsFinished(stage, Some(failureReason)) } catch { case e: UnsupportedOperationException => logInfo(s"Could not cancel tasks for stage $stageId", e) @@ -1224,7 +1242,7 @@ class DAGScheduler( if (ableToCancelStages) { job.listener.jobFailed(error) cleanupStateForJobAndIndependentStages(job) - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error))) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } } @@ -1265,17 +1283,26 @@ class DAGScheduler( } /** - * Synchronized method that might be called from other threads. + * Gets the locality information associated with a partition of a particular RDD. + * + * This method is thread-safe and is called from both DAGScheduler and SparkContext. + * * @param rdd whose partitions are to be looked at * @param partition to lookup locality information for * @return list of machines that are preferred by the partition */ private[spark] - def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized { + def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = { getPreferredLocsInternal(rdd, partition, new HashSet) } - /** Recursive implementation for getPreferredLocs. */ + /** + * Recursive implementation for getPreferredLocs. + * + * This method is thread-safe because it only accesses DAGScheduler state through thread-safe + * methods (getCacheLocs()); please be careful when modifying this method, because any new + * DAGScheduler state accessed by it may require additional synchronization. + */ private def getPreferredLocsInternal( rdd: RDD[_], partition: Int, diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 30075c172bdb1..d99db9e26c1f5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -47,21 +47,30 @@ import org.apache.spark.util.{JsonProtocol, Utils} */ private[spark] class EventLoggingListener( appId: String, - logBaseDir: String, + logBaseDir: URI, sparkConf: SparkConf, hadoopConf: Configuration) extends SparkListener with Logging { import EventLoggingListener._ - def this(appId: String, logBaseDir: String, sparkConf: SparkConf) = + def this(appId: String, logBaseDir: URI, sparkConf: SparkConf) = this(appId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 - private val fileSystem = Utils.getHadoopFileSystem(new URI(logBaseDir), hadoopConf) + private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) + private val compressionCodec = + if (shouldCompress) { + Some(CompressionCodec.createCodec(sparkConf)) + } else { + None + } + private val compressionCodecName = compressionCodec.map { c => + CompressionCodec.getShortName(c.getClass.getName) + } // Only defined if the file system scheme is not local private var hadoopDataStream: Option[FSDataOutputStream] = None @@ -80,7 +89,7 @@ private[spark] class EventLoggingListener( private[scheduler] val loggedEvents = new ArrayBuffer[JValue] // Visible for tests only. - private[scheduler] val logPath = getLogPath(logBaseDir, appId) + private[scheduler] val logPath = getLogPath(logBaseDir, appId, compressionCodecName) /** * Creates the log file in the configured log directory. @@ -111,19 +120,19 @@ private[spark] class EventLoggingListener( hadoopDataStream.get } - val compressionCodec = - if (shouldCompress) { - Some(CompressionCodec.createCodec(sparkConf)) - } else { - None - } - - fileSystem.setPermission(path, LOG_FILE_PERMISSIONS) - val logStream = initEventLog(new BufferedOutputStream(dstream, outputBufferSize), - compressionCodec) - writer = Some(new PrintWriter(logStream)) + try { + val cstream = compressionCodec.map(_.compressedOutputStream(dstream)).getOrElse(dstream) + val bstream = new BufferedOutputStream(cstream, outputBufferSize) - logInfo("Logging events to %s".format(logPath)) + EventLoggingListener.initEventLog(bstream) + fileSystem.setPermission(path, LOG_FILE_PERMISSIONS) + writer = Some(new PrintWriter(bstream)) + logInfo("Logging events to %s".format(logPath)) + } catch { + case e: Exception => + dstream.close() + throw e + } } /** Log the event as JSON. */ @@ -201,77 +210,57 @@ private[spark] object EventLoggingListener extends Logging { // Suffix applied to the names of files still being written by applications. val IN_PROGRESS = ".inprogress" val DEFAULT_LOG_DIR = "/tmp/spark-events" + val SPARK_VERSION_KEY = "SPARK_VERSION" + val COMPRESSION_CODEC_KEY = "COMPRESSION_CODEC" private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) - // Marker for the end of header data in a log file. After this marker, log data, potentially - // compressed, will be found. - private val HEADER_END_MARKER = "=== LOG_HEADER_END ===" - - // To avoid corrupted files causing the heap to fill up. Value is arbitrary. - private val MAX_HEADER_LINE_LENGTH = 4096 - // A cache for compression codecs to avoid creating the same codec many times private val codecMap = new mutable.HashMap[String, CompressionCodec] /** - * Write metadata about the event log to the given stream. - * - * The header is a serialized version of a map, except it does not use Java serialization to - * avoid incompatibilities between different JDKs. It writes one map entry per line, in - * "key=value" format. + * Write metadata about an event log to the given stream. + * The metadata is encoded in the first line of the event log as JSON. * - * The very last entry in the header is the `HEADER_END_MARKER` marker, so that the parsing code - * can know when to stop. - * - * The format needs to be kept in sync with the openEventLog() method below. Also, it cannot - * change in new Spark versions without some other way of detecting the change (like some - * metadata encoded in the file name). - * - * @param logStream Raw output stream to the even log file. - * @param compressionCodec Optional compression codec to use. - * @return A stream where to write event log data. This may be a wrapper around the original - * stream (for example, when compression is enabled). + * @param logStream Raw output stream to the event log file. */ - def initEventLog( - logStream: OutputStream, - compressionCodec: Option[CompressionCodec]): OutputStream = { - val meta = mutable.HashMap(("version" -> SPARK_VERSION)) - compressionCodec.foreach { codec => - meta += ("compressionCodec" -> codec.getClass().getName()) - } - - def write(entry: String) = { - val bytes = entry.getBytes(Charsets.UTF_8) - if (bytes.length > MAX_HEADER_LINE_LENGTH) { - throw new IOException(s"Header entry too long: ${entry}") - } - logStream.write(bytes, 0, bytes.length) - } - - meta.foreach { case (k, v) => write(s"$k=$v\n") } - write(s"$HEADER_END_MARKER\n") - compressionCodec.map(_.compressedOutputStream(logStream)).getOrElse(logStream) + def initEventLog(logStream: OutputStream): Unit = { + val metadata = SparkListenerLogStart(SPARK_VERSION) + val metadataJson = compact(JsonProtocol.logStartToJson(metadata)) + "\n" + logStream.write(metadataJson.getBytes(Charsets.UTF_8)) } /** * Return a file-system-safe path to the log file for the given application. * + * Note that because we currently only create a single log file for each application, + * we must encode all the information needed to parse this event log in the file name + * instead of within the file itself. Otherwise, if the file is compressed, for instance, + * we won't know which codec to use to decompress the metadata needed to open the file in + * the first place. + * * @param logBaseDir Directory where the log file will be written. * @param appId A unique app ID. + * @param compressionCodecName Name to identify the codec used to compress the contents + * of the log, or None if compression is not enabled. * @return A path which consists of file-system-safe characters. */ - def getLogPath(logBaseDir: String, appId: String): String = { - val name = appId.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_").toLowerCase - Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") + def getLogPath( + logBaseDir: URI, + appId: String, + compressionCodecName: Option[String] = None): String = { + val sanitizedAppId = appId.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase + // e.g. app_123, app_123.lzf + val logName = sanitizedAppId + compressionCodecName.map { "." + _ }.getOrElse("") + logBaseDir.toString.stripSuffix("/") + "/" + logName } /** - * Opens an event log file and returns an input stream to the event data. + * Opens an event log file and returns an input stream that contains the event data. * - * @return 2-tuple (event input stream, Spark version of event data) + * @return input stream that holds one JSON record per line. */ - def openEventLog(log: Path, fs: FileSystem): (InputStream, String) = { + def openEventLog(log: Path, fs: FileSystem): InputStream = { // It's not clear whether FileSystem.open() throws FileNotFoundException or just plain // IOException when a file does not exist, so try our best to throw a proper exception. if (!fs.exists(log)) { @@ -279,52 +268,17 @@ private[spark] object EventLoggingListener extends Logging { } val in = new BufferedInputStream(fs.open(log)) - // Read a single line from the input stream without buffering. - // We cannot use BufferedReader because we must avoid reading - // beyond the end of the header, after which the content of the - // file may be compressed. - def readLine(): String = { - val bytes = new ByteArrayOutputStream() - var next = in.read() - var count = 0 - while (next != '\n') { - if (next == -1) { - throw new IOException("Unexpected end of file.") - } - bytes.write(next) - count = count + 1 - if (count > MAX_HEADER_LINE_LENGTH) { - throw new IOException("Maximum header line length exceeded.") - } - next = in.read() - } - new String(bytes.toByteArray(), Charsets.UTF_8) + + // Compression codec is encoded as an extension, e.g. app_123.lzf + // Since we sanitize the app ID to not include periods, it is safe to split on it + val logName = log.getName.stripSuffix(IN_PROGRESS) + val codecName: Option[String] = logName.split("\\.").tail.lastOption + val codec = codecName.map { c => + codecMap.getOrElseUpdate(c, CompressionCodec.createCodec(new SparkConf, c)) } - // Parse the header metadata in the form of k=v pairs - // This assumes that every line before the header end marker follows this format try { - val meta = new mutable.HashMap[String, String]() - var foundEndMarker = false - while (!foundEndMarker) { - readLine() match { - case HEADER_END_MARKER => - foundEndMarker = true - case entry => - val prop = entry.split("=", 2) - if (prop.length != 2) { - throw new IllegalArgumentException("Invalid metadata in log file.") - } - meta += (prop(0) -> prop(1)) - } - } - - val sparkVersion = meta.get("version").getOrElse( - throw new IllegalArgumentException("Missing Spark version in log metadata.")) - val codec = meta.get("compressionCodec").map { codecName => - codecMap.getOrElseUpdate(codecName, CompressionCodec.createCodec(new SparkConf, codecName)) - } - (codec.map(_.compressedInputStream(in)).getOrElse(in), sparkVersion) + codec.map(_.compressedInputStream(in)).getOrElse(in) } catch { case e: Exception => in.close() diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 3bb54855bae44..8aa528ac573d0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -169,7 +169,8 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + + " LOCAL_BYTES_READ=" + metrics.localBytesRead case None => "" } val writeMetrics = taskMetrics.shuffleWriteMetrics match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala new file mode 100644 index 0000000000000..d89721caa9c13 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable + +import akka.actor.{ActorRef, Actor} + +import org.apache.spark._ +import org.apache.spark.util.{AkkaUtils, ActorLogReceive} + +private sealed trait OutputCommitCoordinationMessage extends Serializable + +private case object StopCoordinator extends OutputCommitCoordinationMessage +private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long) + +/** + * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" + * policy. + * + * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is + * configured with a reference to the driver's OutputCommitCoordinatorActor, so requests to commit + * output will be forwarded to the driver's OutputCommitCoordinator. + * + * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) + * for an extensive design discussion. + */ +private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { + + // Initialized by SparkEnv + var coordinatorActor: Option[ActorRef] = None + private val timeout = AkkaUtils.askTimeout(conf) + private val maxAttempts = AkkaUtils.numRetries(conf) + private val retryInterval = AkkaUtils.retryWaitMs(conf) + + private type StageId = Int + private type PartitionId = Long + private type TaskAttemptId = Long + + /** + * Map from active stages's id => partition id => task attempt with exclusive lock on committing + * output for that partition. + * + * Entries are added to the top-level map when stages start and are removed they finish + * (either successfully or unsuccessfully). + * + * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. + */ + private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() + private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + + /** + * Returns whether the OutputCommitCoordinator's internal data structures are all empty. + */ + def isEmpty: Boolean = { + authorizedCommittersByStage.isEmpty + } + + /** + * Called by tasks to ask whether they can commit their output to HDFS. + * + * If a task attempt has been authorized to commit, then all other attempts to commit the same + * task will be denied. If the authorized task attempt fails (e.g. due to its executor being + * lost), then a subsequent task attempt may be authorized to commit its output. + * + * @param stage the stage number + * @param partition the partition number + * @param attempt a unique identifier for this task attempt + * @return true if this task is authorized to commit, false otherwise + */ + def canCommit( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId): Boolean = { + val msg = AskPermissionToCommitOutput(stage, partition, attempt) + coordinatorActor match { + case Some(actor) => + AkkaUtils.askWithReply[Boolean](msg, actor, maxAttempts, retryInterval, timeout) + case None => + logError( + "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") + false + } + } + + // Called by DAGScheduler + private[scheduler] def stageStart(stage: StageId): Unit = synchronized { + authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]() + } + + // Called by DAGScheduler + private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { + authorizedCommittersByStage.remove(stage) + } + + // Called by DAGScheduler + private[scheduler] def taskCompleted( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId, + reason: TaskEndReason): Unit = synchronized { + val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { + logDebug(s"Ignoring task completion for completed stage") + return + }) + reason match { + case Success => + // The task output has been committed successfully + case denied: TaskCommitDenied => + logInfo( + s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") + case otherReason => + if (authorizedCommitters.get(partition).exists(_ == attempt)) { + logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + + s" clearing lock") + authorizedCommitters.remove(partition) + } + } + } + + def stop(): Unit = synchronized { + coordinatorActor.foreach(_ ! StopCoordinator) + coordinatorActor = None + authorizedCommittersByStage.clear() + } + + // Marked private[scheduler] instead of private so this can be mocked in tests + private[scheduler] def handleAskPermissionToCommit( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId): Boolean = synchronized { + authorizedCommittersByStage.get(stage) match { + case Some(authorizedCommitters) => + authorizedCommitters.get(partition) match { + case Some(existingCommitter) => + logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " + + s"existingCommitter = $existingCommitter") + false + case None => + logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition") + authorizedCommitters(partition) = attempt + true + } + case None => + logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + false + } + } +} + +private[spark] object OutputCommitCoordinator { + + // This actor is used only for RPC + class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator) + extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { + case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt) + case StopCoordinator => + logInfo("OutputCommitCoordinator stopped!") + context.stop(self) + sender ! true + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 584f4e7789d1a..95273c716b3e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -39,22 +39,24 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * error is thrown by this method. * * @param logData Stream containing event log data. - * @param version Spark version that generated the events. + * @param sourceName Filename (or other source identifier) from whence @logData is being read */ - def replay(logData: InputStream, version: String) { + def replay(logData: InputStream, sourceName: String): Unit = { var currentLine: String = null + var lineNumber: Int = 1 try { val lines = Source.fromInputStream(logData).getLines() lines.foreach { line => currentLine = line postToAll(JsonProtocol.sparkEventFromJson(parse(line))) + lineNumber += 1 } } catch { case ioe: IOException => throw ioe case e: Exception => - logError("Exception in parsing Spark event log.", e) - logError("Malformed line: %s\n".format(currentLine)) + logError(s"Exception parsing Spark event log: $sourceName", e) + logError(s"Malformed line #$lineNumber: $currentLine\n") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index dd28ddb31de1f..52720d48ca67f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -116,6 +116,11 @@ case class SparkListenerApplicationStart(appName: String, appId: Option[String], @DeveloperApi case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent +/** + * An internal class that describes the metadata of an event log. + * This event is not meant to be posted to listeners downstream. + */ +private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index fe8a19a2c0cb9..61e69ecc08387 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -58,6 +58,7 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case logStart: SparkListenerLogStart => // ignore event log metadata } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 6fa1f2c880f7a..132a9ced77700 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -81,9 +81,11 @@ class TaskInfo( def status: String = { if (running) { - "RUNNING" - } else if (gettingResult) { - "GET RESULT" + if (gettingResult) { + "GET RESULT" + } else { + "RUNNING" + } } else if (failed) { "FAILED" } else if (successful) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 774f3d8cdb275..3938580aeea59 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.concurrent.RejectedExecutionException import scala.language.existentials import scala.util.control.NonFatal @@ -95,25 +96,30 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, serializedData: ByteBuffer) { var reason : TaskEndReason = UnknownReason - getTaskResultExecutor.execute(new Runnable { - override def run(): Unit = Utils.logUncaughtExceptions { - try { - if (serializedData != null && serializedData.limit() > 0) { - reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + try { + getTaskResultExecutor.execute(new Runnable { + override def run(): Unit = Utils.logUncaughtExceptions { + try { + if (serializedData != null && serializedData.limit() > 0) { + reason = serializer.get().deserialize[TaskEndReason]( + serializedData, Utils.getSparkClassLoader) + } + } catch { + case cnd: ClassNotFoundException => + // Log an error but keep going here -- the task failed, so not catastrophic + // if we can't deserialize the reason. + val loader = Utils.getContextOrSparkClassLoader + logError( + "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) + case ex: Exception => {} } - } catch { - case cnd: ClassNotFoundException => - // Log an error but keep going here -- the task failed, so not catastrophic if we can't - // deserialize the reason. - val loader = Utils.getContextOrSparkClassLoader - logError( - "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) - case ex: Exception => {} + scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } - scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) - } - }) + }) + } catch { + case e: RejectedExecutionException if sparkEnv.isStopped => + // ignore it + } } def stop() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 79f84e70df9d5..54f8fcfc416d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -158,7 +158,7 @@ private[spark] class TaskSchedulerImpl( val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new TaskSetManager(this, taskSet, maxTaskFailures) + val manager = createTaskSetManager(taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) @@ -180,6 +180,13 @@ private[spark] class TaskSchedulerImpl( backend.reviveOffers() } + // Label as private[scheduler] to allow tests to swap in different task set managers if necessary + private[scheduler] def createTaskSetManager( + taskSet: TaskSet, + maxTaskFailures: Int): TaskSetManager = { + new TaskSetManager(this, taskSet, maxTaskFailures) + } + override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 97c22fe724abd..529237f0d35dc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -51,7 +51,7 @@ private[spark] class TaskSetManager( sched: TaskSchedulerImpl, val taskSet: TaskSet, val maxTaskFailures: Int, - clock: Clock = SystemClock) + clock: Clock = new SystemClock()) extends Schedulable with Logging { val conf = sched.sc.conf @@ -166,7 +166,7 @@ private[spark] class TaskSetManager( // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. // We then move down if we manage to launch a "more local" task. var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTime() // Time we last launched a task at this level + var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level override def schedulableQueue = null @@ -281,7 +281,7 @@ private[spark] class TaskSetManager( val failed = failedExecutors.get(taskId).get return failed.contains(execId) && - clock.getTime() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT + clock.getTimeMillis() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT } false @@ -292,7 +292,8 @@ private[spark] class TaskSetManager( * an attempt running on this host, in case the host is slow. In addition, the task should meet * the given locality constraint. */ - private def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) + // Labeled as protected to allow tests to override providing speculative tasks if necessary + protected def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) : Option[(Int, TaskLocality.Value)] = { speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set @@ -427,7 +428,7 @@ private[spark] class TaskSetManager( : Option[TaskDescription] = { if (!isZombie) { - val curTime = clock.getTime() + val curTime = clock.getTimeMillis() var allowedLocality = maxLocality @@ -458,7 +459,7 @@ private[spark] class TaskSetManager( lastLaunchTime = curTime } // Serialize and return the task - val startTime = clock.getTime() + val startTime = clock.getTimeMillis() val serializedTask: ByteBuffer = try { Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) } catch { @@ -506,13 +507,64 @@ private[spark] class TaskSetManager( * Get the level we can launch tasks according to delay scheduling, based on current wait time. */ private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { - while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && - currentLocalityIndex < myLocalityLevels.length - 1) - { - // Jump to the next locality level, and remove our waiting time for the current one since - // we don't want to count it again on the next one - lastLaunchTime += localityWaits(currentLocalityIndex) - currentLocalityIndex += 1 + // Remove the scheduled or finished tasks lazily + def tasksNeedToBeScheduledFrom(pendingTaskIds: ArrayBuffer[Int]): Boolean = { + var indexOffset = pendingTaskIds.size + while (indexOffset > 0) { + indexOffset -= 1 + val index = pendingTaskIds(indexOffset) + if (copiesRunning(index) == 0 && !successful(index)) { + return true + } else { + pendingTaskIds.remove(indexOffset) + } + } + false + } + // Walk through the list of tasks that can be scheduled at each location and returns true + // if there are any tasks that still need to be scheduled. Lazily cleans up tasks that have + // already been scheduled. + def moreTasksToRunIn(pendingTasks: HashMap[String, ArrayBuffer[Int]]): Boolean = { + val emptyKeys = new ArrayBuffer[String] + val hasTasks = pendingTasks.exists { + case (id: String, tasks: ArrayBuffer[Int]) => + if (tasksNeedToBeScheduledFrom(tasks)) { + true + } else { + emptyKeys += id + false + } + } + // The key could be executorId, host or rackId + emptyKeys.foreach(id => pendingTasks.remove(id)) + hasTasks + } + + while (currentLocalityIndex < myLocalityLevels.length - 1) { + val moreTasks = myLocalityLevels(currentLocalityIndex) match { + case TaskLocality.PROCESS_LOCAL => moreTasksToRunIn(pendingTasksForExecutor) + case TaskLocality.NODE_LOCAL => moreTasksToRunIn(pendingTasksForHost) + case TaskLocality.NO_PREF => pendingTasksWithNoPrefs.nonEmpty + case TaskLocality.RACK_LOCAL => moreTasksToRunIn(pendingTasksForRack) + } + if (!moreTasks) { + // This is a performance optimization: if there are no more tasks that can + // be scheduled at a particular locality level, there is no point in waiting + // for the locality wait timeout (SPARK-4939). + lastLaunchTime = curTime + logDebug(s"No tasks for locality level ${myLocalityLevels(currentLocalityIndex)}, " + + s"so moving to locality level ${myLocalityLevels(currentLocalityIndex + 1)}") + currentLocalityIndex += 1 + } else if (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex)) { + // Jump to the next locality level, and reset lastLaunchTime so that the next locality + // wait timer doesn't immediately expire + lastLaunchTime += localityWaits(currentLocalityIndex) + currentLocalityIndex += 1 + logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex)} after waiting for " + + s"${localityWaits(currentLocalityIndex)}ms") + } else { + return myLocalityLevels(currentLocalityIndex) + } } myLocalityLevels(currentLocalityIndex) } @@ -622,7 +674,7 @@ private[spark] class TaskSetManager( return } val key = ef.description - val now = clock.getTime() + val now = clock.getTimeMillis() val (printFull, dupCount) = { if (recentExceptions.contains(key)) { val (dupCount, printTime) = recentExceptions(key) @@ -654,10 +706,13 @@ private[spark] class TaskSetManager( } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). - put(info.executorId, clock.getTime()) + put(info.executorId, clock.getTimeMillis()) sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) addPendingTask(index) - if (!isZombie && state != TaskState.KILLED) { + if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) { + // If a task failed because its attempt to commit was denied, do not count this failure + // towards failing the stage. This is intended to prevent spurious stage failures in cases + // where many speculative tasks are launched and denied to commit. assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { @@ -766,7 +821,7 @@ private[spark] class TaskSetManager( val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { - val time = clock.getTime() + val time = clock.getTimeMillis() val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray Arrays.sort(durations) val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 1da6fe976da5b..9bf74f4be198d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -39,7 +39,11 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage // Executors to driver - case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) + case class RegisterExecutor( + executorId: String, + hostPort: String, + cores: Int, + logUrls: Map[String, String]) extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 103a5c053c289..87ebf31139ce9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -86,7 +86,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } def receiveWithLogging = { - case RegisterExecutor(executorId, hostPort, cores) => + case RegisterExecutor(executorId, hostPort, cores, logUrls) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) @@ -98,7 +98,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores) + val data = new ExecutorData(sender, sender.path.address, host, cores, cores, logUrls) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -211,6 +211,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { + addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId executorsPendingToRemove -= executorId } @@ -311,9 +312,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste /** * Request an additional number of executors from the cluster manager. - * Return whether the request is acknowledged. + * @return whether the request is acknowledged. */ final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized { + if (numAdditionalExecutors < 0) { + throw new IllegalArgumentException( + "Attempted to request a negative number of additional executor(s) " + + s"$numAdditionalExecutors from the cluster manager. Please specify a positive number!") + } logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") logDebug(s"Number of pending executors is now $numPendingExecutors") numPendingExecutors += numAdditionalExecutors @@ -322,6 +328,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste doRequestTotalExecutors(newTotal) } + /** + * Express a preference to the cluster manager for a given total number of executors. This can + * result in canceling pending requests or filing additional requests. + * @return whether the request is acknowledged. + */ + final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized { + if (numExecutors < 0) { + throw new IllegalArgumentException( + "Attempted to request a negative number of executor(s) " + + s"$numExecutors from the cluster manager. Please specify a positive number!") + } + numPendingExecutors = + math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) + doRequestTotalExecutors(numExecutors) + } + /** * Request executors from the cluster manager by specifying the total number desired, * including existing pending and running executors. @@ -332,7 +354,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste * insufficient resources to satisfy the first request. We make the assumption here that the * cluster manager will eventually fulfill all requests when resources free up. * - * Return whether the request is acknowledged. + * @return whether the request is acknowledged. */ protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false @@ -350,6 +372,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste logWarning(s"Executor to kill $id does not exist!") } } + // Killing executors means effectively that we want less executors than before, so also update + // the target number of executors to avoid having the backend allocate new ones. + val newTotal = (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size + - filteredExecutorIds.size) + doRequestTotalExecutors(newTotal) + executorsPendingToRemove ++= filteredExecutorIds doKillExecutors(filteredExecutorIds) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index eb52ddfb1eab1..5e571efe76720 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -33,5 +33,6 @@ private[cluster] class ExecutorData( val executorAddress: Address, override val executorHost: String, var freeCores: Int, - override val totalCores: Int -) extends ExecutorInfo(executorHost, totalCores) + override val totalCores: Int, + override val logUrlMap: Map[String, String] +) extends ExecutorInfo(executorHost, totalCores, logUrlMap) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala index b4738e64c9391..7f218566146a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala @@ -25,8 +25,8 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi class ExecutorInfo( val executorHost: String, - val totalCores: Int -) { + val totalCores: Int, + val logUrlMap: Map[String, String]) { def canEqual(other: Any): Boolean = other.isInstanceOf[ExecutorInfo] @@ -34,12 +34,13 @@ class ExecutorInfo( case that: ExecutorInfo => (that canEqual this) && executorHost == that.executorHost && - totalCores == that.totalCores + totalCores == that.totalCores && + logUrlMap == that.logUrlMap case _ => false } override def hashCode(): Int = { - val state = Seq(executorHost, totalCores) + val state = Seq(executorHost, totalCores, logUrlMap) state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index d2e1680a5fd1b..c70d11f5ce0ec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -52,8 +52,13 @@ private[spark] class SparkDeploySchedulerBackend( conf.get("spark.driver.host"), conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}", - "{{WORKER_URL}}") + val args = Seq( + "--driver-url", driverUrl, + "--executor-id", "{{EXECUTOR_ID}}", + "--hostname", "{{HOSTNAME}}", + "--cores", "{{CORES}}", + "--app-id", "{{APP_ID}}", + "--worker-url", "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") .map(Utils.splitCommandString).getOrElse(Seq.empty) val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath") @@ -78,7 +83,7 @@ private[spark] class SparkDeploySchedulerBackend( args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, sc.eventLogDir) + appUIAddress, sc.eventLogDir, sc.eventLogCodec) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 0d1c2a916ca7f..e13de0f46ef89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -28,7 +28,7 @@ import org.apache.mesos.{Scheduler => MScheduler} import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} -import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException} +import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.{Utils, AkkaUtils} @@ -154,18 +154,25 @@ private[spark] class CoarseMesosSchedulerBackend( if (uri == null) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( - "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format( - prefixEnv, runScript, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores, appId)) + "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" + .format(prefixEnv, runScript) + + s" --driver-url $driverUrl" + + s" --executor-id ${offer.getSlaveId.getValue}" + + s" --hostname ${offer.getHostname}" + + s" --cores $numCores" + + s" --app-id $appId") } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.split('/').last.split('.').head command.setValue( - ("cd %s*; %s " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s") - .format(basename, prefixEnv, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores, appId)) + s"cd $basename*; $prefixEnv " + + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + + s" --driver-url $driverUrl" + + s" --executor-id ${offer.getSlaveId.getValue}" + + s" --hostname ${offer.getHostname}" + + s" --cores $numCores" + + s" --app-id $appId") command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } command.build() @@ -255,20 +262,12 @@ private[spark] class CoarseMesosSchedulerBackend( .build() } - /** Check whether a Mesos task state represents a finished task */ - private def isFinished(state: MesosTaskState) = { - state == MesosTaskState.TASK_FINISHED || - state == MesosTaskState.TASK_FAILED || - state == MesosTaskState.TASK_KILLED || - state == MesosTaskState.TASK_LOST - } - override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt val state = status.getState logInfo("Mesos task " + taskId + " is now " + state) synchronized { - if (isFinished(state)) { + if (TaskState.isFinished(TaskState.fromMesos(state))) { val slaveId = taskIdToSlaveId(taskId) slaveIdsWithExecutors -= slaveId taskIdToSlaveId -= taskId @@ -278,7 +277,7 @@ private[spark] class CoarseMesosSchedulerBackend( coresByTaskId -= taskId } // If it was a failure, mark the slave as failed for blacklisting purposes - if (state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_LOST) { + if (TaskState.isFailed(TaskState.fromMesos(state))) { failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index c3c546be6da15..06bb527522141 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -270,7 +270,8 @@ private[spark] class MesosSchedulerBackend( mesosTasks.foreach { case (slaveId, tasks) => slaveIdToWorkerOffer.get(slaveId).foreach(o => listenerBus.post(SparkListenerExecutorAdded(System.currentTimeMillis(), slaveId, - new ExecutorInfo(o.host, o.cores))) + // TODO: Add support for log urls for Mesos + new ExecutorInfo(o.host, o.cores, Map.empty))) ) d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } @@ -312,24 +313,17 @@ private[spark] class MesosSchedulerBackend( .build() } - /** Check whether a Mesos task state represents a finished task */ - def isFinished(state: MesosTaskState) = { - state == MesosTaskState.TASK_FINISHED || - state == MesosTaskState.TASK_FAILED || - state == MesosTaskState.TASK_KILLED || - state == MesosTaskState.TASK_LOST - } - override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { inClassLoader() { val tid = status.getTaskId.getValue.toLong val state = TaskState.fromMesos(status.getState) synchronized { - if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { + if (TaskState.isFailed(TaskState.fromMesos(status.getState)) + && taskIdToSlaveId.contains(tid)) { // We lost the executor on this slave, so remember that it's gone removeExecutor(taskIdToSlaveId(tid), "Lost executor") } - if (isFinished(status.getState)) { + if (TaskState.isFinished(state)) { taskIdToSlaveId.remove(tid) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 05b6fa54564b7..4676b828d3d89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer +import scala.concurrent.duration._ + import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} @@ -46,6 +48,8 @@ private[spark] class LocalActor( private val totalCores: Int) extends Actor with ActorLogReceive with Logging { + import context.dispatcher // to use Akka's scheduler.scheduleOnce() + private var freeCores = totalCores private val localExecutorId = SparkContext.DRIVER_IDENTIFIER @@ -74,11 +78,16 @@ private[spark] class LocalActor( def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) - for (task <- scheduler.resourceOffers(offers).flatten) { + val tasks = scheduler.resourceOffers(offers).flatten + for (task <- tasks) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, task.name, task.serializedTask) } + if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) { + // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout + context.system.scheduler.scheduleOnce(1000 millis, self, ReviveOffers) + } } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d56e23ce4478a..dc7aa99738c17 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -20,22 +20,23 @@ package org.apache.spark.serializer import java.io.{EOFException, InputStream, OutputStream} import java.nio.ByteBuffer +import scala.reflect.ClassTag + import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} +import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock} +import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.CompactBuffer -import scala.reflect.ClassTag - /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. * @@ -58,14 +59,6 @@ class KryoSerializer(conf: SparkConf) private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") .split(',') .filter(!_.isEmpty) - .map { className => - try { - Class.forName(className) - } catch { - case e: Exception => - throw new SparkException("Failed to load class to register with Kryo", e) - } - } def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) @@ -97,7 +90,8 @@ class KryoSerializer(conf: SparkConf) // Use the default classloader when calling the user registrator. Thread.currentThread.setContextClassLoader(classLoader) // Register classes given through spark.kryo.classesToRegister. - classesToRegister.foreach { clazz => kryo.register(clazz) } + classesToRegister + .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. userRegistrator .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) @@ -164,7 +158,13 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() - kryo.writeClassAndObject(output, t) + try { + kryo.writeClassAndObject(output, t) + } catch { + case e: KryoException if e.getMessage.startsWith("Buffer overflow") => + throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + + "increase spark.kryoserializer.buffer.max.mb value.") + } ByteBuffer.wrap(output.toBytes) } @@ -209,9 +209,17 @@ private[serializer] object KryoSerializer { classOf[GetBlock], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], + classOf[RoaringBitmap], + classOf[RoaringArray], + classOf[RoaringArray.Element], + classOf[Array[RoaringArray.Element]], + classOf[ArrayContainer], + classOf[BitmapContainer], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], + classOf[Array[Short]], + classOf[Array[Long]], classOf[BoundedPriorityQueue[_]], classOf[SparkConf] ) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 7de2f9cbb2866..7aec9f90f1e0b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -112,6 +112,7 @@ class FileShuffleBlockManager(conf: SparkConf) private val shuffleState = shuffleStates(shuffleId) private var fileGroup: ShuffleFileGroup = null + val openStartTime = System.nanoTime val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => @@ -135,6 +136,9 @@ class FileShuffleBlockManager(conf: SparkConf) blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) } } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, so should be included in the shuffle write time. + writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { if (consolidateShuffleFiles) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index e3e7434df45b0..7a2c5ae32d98b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -86,6 +86,12 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { context.taskMetrics.updateShuffleReadMetrics() }) - new InterruptibleIterator[T](context, completionIter) + new InterruptibleIterator[T](context, completionIter) { + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + override def next(): T = { + readMetrics.incRecordsRead(1) + delegate.next() + } + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 27496c5a289cb..621b264b5e482 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -63,6 +63,9 @@ private[spark] class SortShuffleWriter[K, V, C]( sorter.insertAll(records) } + // Don't bother including the time to open the merged output file in the shuffle write time, + // because it just opens a single file, so is typically too fast to measure accurately + // (see SPARK-3570). val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId) val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 8bc5a1cd18b64..86dbd89f0ffb8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -53,7 +53,7 @@ private[spark] class BlockResult( readMethod: DataReadMethod.Value, bytes: Long) { val inputMetrics = new InputMetrics(readMethod) - inputMetrics.addBytesRead(bytes) + inputMetrics.incBytesRead(bytes) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 3198d766fca37..81164178b9e8e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -29,7 +29,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics * appending data to an existing block, and can guarantee atomicity in the case of faults * as it allows the caller to revert partial writes. * - * This interface does not support concurrent writes. + * This interface does not support concurrent writes. Also, once the writer has + * been opened, it cannot be reopened again. */ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { @@ -95,6 +96,7 @@ private[spark] class DiskBlockObjectWriter( private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null private var initialized = false + private var hasBeenClosed = false /** * Cursors used to represent positions in the file. @@ -115,11 +117,16 @@ private[spark] class DiskBlockObjectWriter( private var finalPosition: Long = -1 private var reportedPosition = initialPosition - /** Calling channel.position() to update the write metrics can be a little bit expensive, so we - * only call it every N writes */ - private var writesSinceMetricsUpdate = 0 + /** + * Keep track of number of records written and also use this to periodically + * output bytes written since the latter is expensive to do for each record. + */ + private var numRecordsWritten = 0 override def open(): BlockObjectWriter = { + if (hasBeenClosed) { + throw new IllegalStateException("Writer already closed. Cannot be reopened.") + } fos = new FileOutputStream(file, true) ts = new TimeTrackingOutputStream(fos) channel = fos.getChannel() @@ -145,6 +152,7 @@ private[spark] class DiskBlockObjectWriter( ts = null objOut = null initialized = false + hasBeenClosed = true } } @@ -168,6 +176,7 @@ private[spark] class DiskBlockObjectWriter( override def revertPartialWritesAndClose() { try { writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) + writeMetrics.decShuffleRecordsWritten(numRecordsWritten) if (initialized) { objOut.flush() @@ -193,12 +202,11 @@ private[spark] class DiskBlockObjectWriter( } objOut.writeObject(value) + numRecordsWritten += 1 + writeMetrics.incShuffleRecordsWritten(1) - if (writesSinceMetricsUpdate == 32) { - writesSinceMetricsUpdate = 0 + if (numRecordsWritten % 32 == 0) { updateBytesWritten() - } else { - writesSinceMetricsUpdate += 1 } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 53eaedacbf291..12cd8ea3bdf1f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -49,7 +49,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - addShutdownHook() + private val shutdownHook = addShutdownHook() /** Looks up a file by hashing it into one of our local subdirectories. */ // This method should be kept in sync with @@ -134,17 +134,29 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } } - private def addShutdownHook() { - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { + private def addShutdownHook(): Thread = { + val shutdownHook = new Thread("delete Spark local dirs") { override def run(): Unit = Utils.logUncaughtExceptions { logDebug("Shutdown hook called") - DiskBlockManager.this.stop() + DiskBlockManager.this.doStop() } - }) + } + Runtime.getRuntime.addShutdownHook(shutdownHook) + shutdownHook } /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { + // Remove the shutdown hook. It causes memory leaks if we leave it around. + try { + Runtime.getRuntime.removeShutdownHook(shutdownHook) + } catch { + case e: IllegalStateException => None + } + doStop() + } + + private def doStop(): Unit = { // Only perform cleanup if an external service is not serving our shuffle files. if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) { localDirs.foreach { localDir => diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index ab9ee4f0096bf..8f28ef49a8a6f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -234,6 +234,7 @@ final class ShuffleBlockFetcherIterator( try { val buf = blockManager.getBlockData(blockId) shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() results.put(new SuccessFetchResult(blockId, 0, buf)) } catch { diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 27ba9e18237b5..67f572e79314d 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -28,7 +28,6 @@ import org.apache.spark._ * of them will be combined together, showed in one line. */ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { - // Carrige return val CR = '\r' // Update period of progress bar, in milliseconds @@ -121,4 +120,10 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { clear() lastFinishTime = System.currentTimeMillis() } + + /** + * Tear down the timer thread. The timer thread is a GC root, and it retains the entire + * SparkContext if it's not terminated. + */ + def stop(): Unit = timer.cancel() } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 88fed833f922d..bf4b24e98b134 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -62,17 +62,22 @@ private[spark] object JettyUtils extends Logging { securityMgr: SecurityManager): HttpServlet = { new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) { - response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) - response.setStatus(HttpServletResponse.SC_OK) - val result = servletParams.responder(request) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.getWriter.println(servletParams.extractFn(result)) - } else { - response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.sendError(HttpServletResponse.SC_UNAUTHORIZED, - "User is not authorized to access this page.") + try { + if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) { + response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) + response.setStatus(HttpServletResponse.SC_OK) + val result = servletParams.responder(request) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.getWriter.println(servletParams.extractFn(result)) + } else { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + "User is not authorized to access this page.") + } + } catch { + case e: IllegalArgumentException => + response.sendError(HttpServletResponse.SC_BAD_REQUEST, e.getMessage) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 4307029d44fbb..cae6870c2ab20 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -29,15 +29,20 @@ private[spark] object ToolTips { val SHUFFLE_READ_BLOCKED_TIME = "Time that the task spent blocked waiting for shuffle data to be read from remote machines." - val INPUT = "Bytes read from Hadoop or from Spark storage." + val INPUT = "Bytes and records read from Hadoop or from Spark storage." - val OUTPUT = "Bytes written to Hadoop." + val OUTPUT = "Bytes and records written to Hadoop." - val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." + val SHUFFLE_WRITE = + "Bytes and records written to disk in order to be read by a shuffle in a future stage." val SHUFFLE_READ = - """Bytes read from remote executors. Typically less than shuffle write bytes - because this does not include shuffle data read locally.""" + """Total shuffle bytes and records read (includes both data read locally and data read from + remote executors). """ + + val SHUFFLE_READ_REMOTE_SIZE = + """Total shuffle bytes read from remote executors. This is a subset of the shuffle + read bytes; the remaining shuffle data is read locally. """ val GETTING_RESULT_TIME = """Time that the driver spends fetching task results from workers. If this is large, consider diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 9be65a4a39a09..ea548f23120d9 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -20,14 +20,15 @@ package org.apache.spark.ui import javax.servlet.http.HttpServletRequest import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap import scala.xml.Node import org.eclipse.jetty.servlet.ServletContextHandler import org.json4s.JsonAST.{JNothing, JValue} -import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * The top level component of the UI hierarchy that contains the server. @@ -45,9 +46,10 @@ private[spark] abstract class WebUI( protected val tabs = ArrayBuffer[WebUITab]() protected val handlers = ArrayBuffer[ServletContextHandler]() + protected val pageToHandlers = new HashMap[WebUIPage, ArrayBuffer[ServletContextHandler]] protected var serverInfo: Option[ServerInfo] = None protected val localHostName = Utils.localHostName() - protected val publicHostName = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) + protected val publicHostName = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) private val className = Utils.getFormattedClassName(this) def getBasePath: String = basePath @@ -60,14 +62,30 @@ private[spark] abstract class WebUI( tab.pages.foreach(attachPage) tabs += tab } + + def detachTab(tab: WebUITab) { + tab.pages.foreach(detachPage) + tabs -= tab + } + + def detachPage(page: WebUIPage) { + pageToHandlers.remove(page).foreach(_.foreach(detachHandler)) + } /** Attach a page to this UI. */ def attachPage(page: WebUIPage) { val pagePath = "/" + page.prefix - attachHandler(createServletHandler(pagePath, - (request: HttpServletRequest) => page.render(request), securityManager, basePath)) - attachHandler(createServletHandler(pagePath.stripSuffix("/") + "/json", - (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath)) + val renderHandler = createServletHandler(pagePath, + (request: HttpServletRequest) => page.render(request), securityManager, basePath) + val renderJsonHandler = createServletHandler(pagePath.stripSuffix("/") + "/json", + (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath) + attachHandler(renderHandler) + attachHandler(renderJsonHandler) + pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) + .append(renderHandler) + pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) + .append(renderJsonHandler) + } /** Attach a handler to this UI. */ diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index c82730f524eb7..f0ae95bb8c812 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -43,7 +43,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } id }.getOrElse { - return Text(s"Missing executorId parameter") + throw new IllegalArgumentException(s"Missing executorId parameter") } val time = System.currentTimeMillis() val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 363cb96de7998..956608d7c0cbe 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -26,7 +26,8 @@ import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} import org.apache.spark.util.Utils /** Summary information about an executor to display in the UI. */ -private case class ExecutorSummaryInfo( +// Needs to be private[ui] because of a false positive MiMa failure. +private[ui] case class ExecutorSummaryInfo( id: String, hostPort: String, rddBlocks: Int, @@ -40,7 +41,8 @@ private case class ExecutorSummaryInfo( totalInputBytes: Long, totalShuffleRead: Long, totalShuffleWrite: Long, - maxMemory: Long) + maxMemory: Long, + executorLogs: Map[String, String]) private[ui] class ExecutorsPage( parent: ExecutorsTab, @@ -55,6 +57,7 @@ private[ui] class ExecutorsPage( val diskUsed = storageStatusList.map(_.diskUsed).sum val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId) val execInfoSorted = execInfo.sortBy(_.id) + val logsExist = execInfo.filter(_.executorLogs.nonEmpty).nonEmpty val execTable = @@ -79,10 +82,11 @@ private[ui] class ExecutorsPage( Shuffle Write + {if (logsExist) else Seq.empty} {if (threadDumpEnabled) else Seq.empty} - {execInfoSorted.map(execRow)} + {execInfoSorted.map(execRow(_, logsExist))}
    LogsThread Dump
    @@ -107,7 +111,7 @@ private[ui] class ExecutorsPage( } /** Render an HTML row representing an executor */ - private def execRow(info: ExecutorSummaryInfo): Seq[Node] = { + private def execRow(info: ExecutorSummaryInfo, logsExist: Boolean): Seq[Node] = { val maximumMemory = info.maxMemory val memoryUsed = info.memoryUsed val diskUsed = info.diskUsed @@ -138,6 +142,21 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(info.totalShuffleWrite)} + { + if (logsExist) { + + { + info.executorLogs.map { case (logName, logUrl) => + + } + } + + } + } { if (threadDumpEnabled) { val encodedId = URLEncoder.encode(info.id, "UTF-8") @@ -168,6 +187,7 @@ private[ui] class ExecutorsPage( val totalInputBytes = listener.executorToInputBytes.getOrElse(execId, 0L) val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0L) val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0L) + val executorLogs = listener.executorToLogUrls.getOrElse(execId, Map.empty) new ExecutorSummaryInfo( execId, @@ -183,7 +203,8 @@ private[ui] class ExecutorsPage( totalInputBytes, totalShuffleRead, totalShuffleWrite, - maxMem + maxMem, + executorLogs ) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index dd1c2b78c4094..3afd7ef07d7c9 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -48,12 +48,20 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val executorToTasksFailed = HashMap[String, Int]() val executorToDuration = HashMap[String, Long]() val executorToInputBytes = HashMap[String, Long]() + val executorToInputRecords = HashMap[String, Long]() val executorToOutputBytes = HashMap[String, Long]() + val executorToOutputRecords = HashMap[String, Long]() val executorToShuffleRead = HashMap[String, Long]() val executorToShuffleWrite = HashMap[String, Long]() + val executorToLogUrls = HashMap[String, Map[String, String]]() def storageStatusList = storageStatusListener.storageStatusList + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) = synchronized { + val eid = executorAdded.executorId + executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap + } + override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { val eid = taskStart.taskInfo.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 @@ -78,10 +86,14 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp metrics.inputMetrics.foreach { inputMetrics => executorToInputBytes(eid) = executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead + executorToInputRecords(eid) = + executorToInputRecords.getOrElse(eid, 0L) + inputMetrics.recordsRead } metrics.outputMetrics.foreach { outputMetrics => executorToOutputBytes(eid) = executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten + executorToOutputRecords(eid) = + executorToOutputRecords.getOrElse(eid, 0L) + outputMetrics.recordsWritten } metrics.shuffleReadMetrics.foreach { shuffleRead => executorToShuffleRead(eid) = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 045c69da06feb..bd923d78a86ce 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -42,7 +42,9 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } def makeRow(job: JobUIData): Seq[Node] = { - val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max) + val lastStageInfo = Option(job.stageIds) + .filter(_.nonEmpty) + .flatMap { ids => listener.stageIdToInfo.get(ids.max) } val lastStageData = lastStageInfo.flatMap { s => listener.stageIdToData.get((s.stageId, s.attemptId)) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 9836d11a6d85f..1f8536d1b7195 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -36,6 +36,20 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage /** Special table which merges two header cells. */ private def executorTable[T](): Seq[Node] = { + val stageData = listener.stageIdToData.get((stageId, stageAttemptId)) + var hasInput = false + var hasOutput = false + var hasShuffleWrite = false + var hasShuffleRead = false + var hasBytesSpilled = false + stageData.foreach(data => { + hasInput = data.hasInput + hasOutput = data.hasOutput + hasShuffleRead = data.hasShuffleRead + hasShuffleWrite = data.hasShuffleWrite + hasBytesSpilled = data.hasBytesSpilled + }) + @@ -44,12 +58,32 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage - - - - - - + {if (hasInput) { + + }} + {if (hasOutput) { + + }} + {if (hasShuffleRead) { + + }} + {if (hasShuffleWrite) { + + }} + {if (hasBytesSpilled) { + + + }} {createExecutorTable()} @@ -76,18 +110,34 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage - - - - - - + {if (stageData.hasInput) { + + }} + {if (stageData.hasOutput) { + + }} + {if (stageData.hasShuffleRead) { + + }} + {if (stageData.hasShuffleWrite) { + + }} + {if (stageData.hasBytesSpilled) { + + + }} } case None => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 77d36209c6048..7541d3e9c72e7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -32,7 +32,10 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { - val jobId = request.getParameter("id").toInt + val parameterId = request.getParameter("id") + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val jobId = parameterId.toInt val jobDataOption = listener.jobIdToData.get(jobId) if (jobDataOption.isEmpty) { val content = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 4d200eeda86b9..78eae535c6eaa 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -44,6 +44,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // These type aliases are public because they're used in the types of public fields: type JobId = Int + type JobGroupId = String type StageId = Int type StageAttemptId = Int type PoolName = String @@ -54,6 +55,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val completedJobs = ListBuffer[JobUIData]() val failedJobs = ListBuffer[JobUIData]() val jobIdToData = new HashMap[JobId, JobUIData] + val jobGroupToJobIds = new HashMap[JobGroupId, HashSet[JobId]] // Stages: val pendingStages = new HashMap[StageId, StageInfo] @@ -119,7 +121,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { Map( "jobIdToData" -> jobIdToData.size, "stageIdToData" -> stageIdToData.size, - "stageIdToStageInfo" -> stageIdToInfo.size + "stageIdToStageInfo" -> stageIdToInfo.size, + "jobGroupToJobIds" -> jobGroupToJobIds.values.map(_.size).sum, + // Since jobGroupToJobIds is map of sets, check that we don't leak keys with empty values: + "jobGroupToJobIds keySet" -> jobGroupToJobIds.keys.size ) } @@ -140,7 +145,19 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { if (jobs.size > retainedJobs) { val toRemove = math.max(retainedJobs / 10, 1) jobs.take(toRemove).foreach { job => - jobIdToData.remove(job.jobId) + // Remove the job's UI data, if it exists + jobIdToData.remove(job.jobId).foreach { removedJob => + // A null jobGroupId is used for jobs that are run without a job group + val jobGroupId = removedJob.jobGroup.orNull + // Remove the job group -> job mapping entry, if it exists + jobGroupToJobIds.get(jobGroupId).foreach { jobsInGroup => + jobsInGroup.remove(job.jobId) + // If this was the last job in this job group, remove the map entry for the job group + if (jobsInGroup.isEmpty) { + jobGroupToJobIds.remove(jobGroupId) + } + } + } } jobs.trimStart(toRemove) } @@ -158,6 +175,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageIds = jobStart.stageIds, jobGroup = jobGroup, status = JobExecutionStatus.RUNNING) + // A null jobGroupId is used for jobs that are run without a job group + jobGroupToJobIds.getOrElseUpdate(jobGroup.orNull, new HashSet[JobId]).add(jobStart.jobId) jobStart.stageInfos.foreach(x => pendingStages(x.stageId) = x) // Compute (a potential underestimate of) the number of tasks that will be run by this job. // This may be an underestimate because the job start event references all of the result @@ -203,6 +222,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { for (stageId <- jobData.stageIds) { stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage => jobsUsingStage.remove(jobEnd.jobId) + if (jobsUsingStage.isEmpty) { + stageIdToActiveJobIds.remove(stageId) + } stageIdToInfo.get(stageId).foreach { stageInfo => if (stageInfo.submissionTime.isEmpty) { // if this stage is pending, it won't complete, so mark it as "skipped": @@ -394,24 +416,48 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.shuffleWriteBytes += shuffleWriteDelta execSummary.shuffleWrite += shuffleWriteDelta + val shuffleWriteRecordsDelta = + (taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleRecordsWritten).getOrElse(0L)) + stageData.shuffleWriteRecords += shuffleWriteRecordsDelta + execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta + val shuffleReadDelta = - (taskMetrics.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L)) - stageData.shuffleReadBytes += shuffleReadDelta + (taskMetrics.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.totalBytesRead).getOrElse(0L)) + stageData.shuffleReadTotalBytes += shuffleReadDelta execSummary.shuffleRead += shuffleReadDelta + val shuffleReadRecordsDelta = + (taskMetrics.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.recordsRead).getOrElse(0L)) + stageData.shuffleReadRecords += shuffleReadRecordsDelta + execSummary.shuffleReadRecords += shuffleReadRecordsDelta + val inputBytesDelta = (taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L) - oldMetrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L)) stageData.inputBytes += inputBytesDelta execSummary.inputBytes += inputBytesDelta + val inputRecordsDelta = + (taskMetrics.inputMetrics.map(_.recordsRead).getOrElse(0L) + - oldMetrics.flatMap(_.inputMetrics).map(_.recordsRead).getOrElse(0L)) + stageData.inputRecords += inputRecordsDelta + execSummary.inputRecords += inputRecordsDelta + val outputBytesDelta = (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L) - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L)) stageData.outputBytes += outputBytesDelta execSummary.outputBytes += outputBytesDelta + val outputRecordsDelta = + (taskMetrics.outputMetrics.map(_.recordsWritten).getOrElse(0L) + - oldMetrics.flatMap(_.outputMetrics).map(_.recordsWritten).getOrElse(0L)) + stageData.outputRecords += outputRecordsDelta + execSummary.outputRecords += outputRecordsDelta + val diskSpillDelta = taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) stageData.diskBytesSpilled += diskSpillDelta diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 5fc6cc7533150..f47cdc935e539 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -32,6 +32,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { val poolName = request.getParameter("poolname") + require(poolName != null && poolName.nonEmpty, "Missing poolname parameter") + val poolToActiveStages = listener.poolToActiveStages val activeStages = poolToActiveStages.get(poolName) match { case Some(s) => s.values.toSeq diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index d8be1b20b3acd..a1215562eb427 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -36,8 +36,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { - val stageId = request.getParameter("id").toInt - val stageAttemptId = request.getParameter("attempt").toInt + val parameterId = request.getParameter("id") + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val parameterAttempt = request.getParameter("attempt") + require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") + + val stageId = parameterId.toInt + val stageAttemptId = parameterAttempt.toInt val stageDataOption = listener.stageIdToData.get((stageId, stageAttemptId)) if (stageDataOption.isEmpty || stageDataOption.get.taskData.isEmpty) { @@ -56,11 +62,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val numCompleted = tasks.count(_.taskInfo.finished) val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables val hasAccumulators = accumulables.size > 0 - val hasInput = stageData.inputBytes > 0 - val hasOutput = stageData.outputBytes > 0 - val hasShuffleRead = stageData.shuffleReadBytes > 0 - val hasShuffleWrite = stageData.shuffleWriteBytes > 0 - val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0 val summary =
    @@ -69,31 +70,33 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Total task time across all tasks: {UIUtils.formatDuration(stageData.executorRunTime)} - {if (hasInput) { + {if (stageData.hasInput) {
  • - Input: - {Utils.bytesToString(stageData.inputBytes)} + Input Size / Records: + {s"${Utils.bytesToString(stageData.inputBytes)} / ${stageData.inputRecords}"}
  • }} - {if (hasOutput) { + {if (stageData.hasOutput) {
  • Output: - {Utils.bytesToString(stageData.outputBytes)} + {s"${Utils.bytesToString(stageData.outputBytes)} / ${stageData.outputRecords}"}
  • }} - {if (hasShuffleRead) { + {if (stageData.hasShuffleRead) {
  • Shuffle read: - {Utils.bytesToString(stageData.shuffleReadBytes)} + {s"${Utils.bytesToString(stageData.shuffleReadTotalBytes)} / " + + s"${stageData.shuffleReadRecords}"}
  • }} - {if (hasShuffleWrite) { + {if (stageData.hasShuffleWrite) {
  • Shuffle write: - {Utils.bytesToString(stageData.shuffleWriteBytes)} + {s"${Utils.bytesToString(stageData.shuffleWriteBytes)} / " + + s"${stageData.shuffleWriteRecords}"}
  • }} - {if (hasBytesSpilled) { + {if (stageData.hasBytesSpilled) {
  • Shuffle spill (memory): {Utils.bytesToString(stageData.memoryBytesSpilled)} @@ -132,7 +135,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Task Deserialization Time
  • - {if (hasShuffleRead) { + {if (stageData.hasShuffleRead) {
  • @@ -140,6 +143,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Shuffle Read Blocked Time
  • +
  • + + + Shuffle Remote Reads + +
  • }}
  • + getDistributionQuantiles(times).map { millis =>
  • } } @@ -247,11 +268,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - if (info.gettingResultTime > 0) { - (info.finishTime - info.gettingResultTime).toDouble - } else { - 0.0 - } + getGettingResultTime(info).toDouble } val gettingResultQuantiles = ) + getDistributionQuantiles(data).map(d => ) + + def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) = { + val recordDist = getDistributionQuantiles(records).iterator + getDistributionQuantiles(data).map(d => + + ) + } val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble } - val inputQuantiles = +: getFormattedSizeQuantiles(inputSizes) + + val inputRecords = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble + } + + val inputQuantiles = +: + getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble } - val outputQuantiles = +: getFormattedSizeQuantiles(outputSizes) + + val outputRecords = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble + } + + val outputQuantiles = +: + getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble } - val shuffleReadBlockedQuantiles = +: + val shuffleReadBlockedQuantiles = + +: getFormattedTimeQuantiles(shuffleReadBlockedTimes) - val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => + val shuffleReadTotalSizes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble + } + val shuffleReadTotalRecords = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble + } + val shuffleReadTotalQuantiles = + +: + getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) + + val shuffleReadRemoteSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } - val shuffleReadQuantiles = +: - getFormattedSizeQuantiles(shuffleReadSizes) + val shuffleReadRemoteQuantiles = + +: + getFormattedSizeQuantiles(shuffleReadRemoteSizes) val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble } - val shuffleWriteQuantiles = +: - getFormattedSizeQuantiles(shuffleWriteSizes) + + val shuffleWriteRecords = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L).toDouble + } + + val shuffleWriteQuantiles = +: + getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.memoryBytesSpilled.toDouble @@ -326,19 +394,22 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {serializationQuantiles} , {gettingResultQuantiles}, - if (hasInput) {inputQuantiles} else Nil, - if (hasOutput) {outputQuantiles} else Nil, - if (hasShuffleRead) { + if (stageData.hasInput) {inputQuantiles} else Nil, + if (stageData.hasOutput) {outputQuantiles} else Nil, + if (stageData.hasShuffleRead) { {shuffleReadBlockedQuantiles} - {shuffleReadQuantiles} + {shuffleReadTotalQuantiles} + + {shuffleReadRemoteQuantiles} + } else { Nil }, - if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, - if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, - if (hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) + if (stageData.hasShuffleWrite) {shuffleWriteQuantiles} else Nil, + if (stageData.hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, + if (stageData.hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", "Max") @@ -387,7 +458,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = info.gettingResultTime + val gettingResultTime = getGettingResultTime(info) val maybeAccumulators = info.accumulables val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} @@ -397,26 +468,36 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val inputReadable = maybeInput .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") .getOrElse("") + val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") val maybeOutput = metrics.flatMap(_.outputMetrics) val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("") val outputReadable = maybeOutput .map(m => s"${Utils.bytesToString(m.bytesWritten)}") .getOrElse("") + val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") - val maybeShuffleReadBlockedTime = metrics.flatMap(_.shuffleReadMetrics).map(_.fetchWaitTime) - val shuffleReadBlockedTimeSortable = maybeShuffleReadBlockedTime.map(_.toString).getOrElse("") + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) + val shuffleReadBlockedTimeSortable = maybeShuffleRead + .map(_.fetchWaitTime.toString).getOrElse("") val shuffleReadBlockedTimeReadable = - maybeShuffleReadBlockedTime.map(ms => UIUtils.formatDuration(ms)).getOrElse("") + maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead) - val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") - val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("") + val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) + val shuffleReadSortable = totalShuffleBytes.map(_.toString).getOrElse("") + val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") + val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") - val maybeShuffleWrite = - metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten) - val shuffleWriteSortable = maybeShuffleWrite.map(_.toString).getOrElse("") - val shuffleWriteReadable = maybeShuffleWrite.map(Utils.bytesToString).getOrElse("") + val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) + val shuffleReadRemoteSortable = remoteShuffleBytes.map(_.toString).getOrElse("") + val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") + + val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) + val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten.toString).getOrElse("") + val shuffleWriteReadable = maybeShuffleWrite + .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") + val shuffleWriteRecords = maybeShuffleWrite + .map(_.shuffleRecordsWritten.toString).getOrElse("") val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") @@ -472,12 +553,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { }} {if (hasInput) { }} {if (hasOutput) { }} {if (hasShuffleRead) { @@ -486,7 +567,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {shuffleReadBlockedTimeReadable} + }} {if (hasShuffleWrite) { @@ -494,7 +579,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {writeTimeReadable} }} {if (hasBytesSpilled) { @@ -536,16 +621,32 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } - private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { - val totalExecutionTime = { - if (info.gettingResultTime > 0) { - (info.gettingResultTime - info.launchTime) + private def getGettingResultTime(info: TaskInfo): Long = { + if (info.gettingResultTime > 0) { + if (info.finishTime > 0) { + info.finishTime - info.gettingResultTime } else { - (info.finishTime - info.launchTime) + // The task is still fetching the result. + System.currentTimeMillis - info.gettingResultTime } + } else { + 0L } + } + + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { + val totalExecutionTime = + if (info.gettingResult) { + info.gettingResultTime - info.launchTime + } else if (info.finished) { + info.finishTime - info.launchTime + } else { + 0 + } val executorOverhead = (metrics.executorDeserializeTime + metrics.resultSerializationTime) - totalExecutionTime - metrics.executorRunTime - executorOverhead + math.max( + 0, + totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info)) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 703d43f9c640d..5865850fa09b5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -138,7 +138,7 @@ private[ui] class StageTableBase( val inputReadWithUnit = if (inputRead > 0) Utils.bytesToString(inputRead) else "" val outputWrite = stageData.outputBytes val outputWriteWithUnit = if (outputWrite > 0) Utils.bytesToString(outputWrite) else "" - val shuffleRead = stageData.shuffleReadBytes + val shuffleRead = stageData.shuffleReadTotalBytes val shuffleReadWithUnit = if (shuffleRead > 0) Utils.bytesToString(shuffleRead) else "" val shuffleWrite = stageData.shuffleWriteBytes val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 37cf2c207ba40..9bf67db8acde1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -28,6 +28,7 @@ private[spark] object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val TASK_DESERIALIZATION_TIME = "deserialization_time" val SHUFFLE_READ_BLOCKED_TIME = "fetch_wait_time" + val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 01f7e23212c3d..dbf1ceeda1878 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -31,9 +31,13 @@ private[jobs] object UIData { var failedTasks : Int = 0 var succeededTasks : Int = 0 var inputBytes : Long = 0 + var inputRecords : Long = 0 var outputBytes : Long = 0 + var outputRecords : Long = 0 var shuffleRead : Long = 0 + var shuffleReadRecords : Long = 0 var shuffleWrite : Long = 0 + var shuffleWriteRecords : Long = 0 var memoryBytesSpilled : Long = 0 var diskBytesSpilled : Long = 0 } @@ -73,9 +77,13 @@ private[jobs] object UIData { var executorRunTime: Long = _ var inputBytes: Long = _ + var inputRecords: Long = _ var outputBytes: Long = _ - var shuffleReadBytes: Long = _ + var outputRecords: Long = _ + var shuffleReadTotalBytes: Long = _ + var shuffleReadRecords : Long = _ var shuffleWriteBytes: Long = _ + var shuffleWriteRecords: Long = _ var memoryBytesSpilled: Long = _ var diskBytesSpilled: Long = _ @@ -85,6 +93,12 @@ private[jobs] object UIData { var accumulables = new HashMap[Long, AccumulableInfo] var taskData = new HashMap[Long, TaskUIData] var executorSummary = new HashMap[String, ExecutorSummary] + + def hasInput = inputBytes > 0 + def hasOutput = outputBytes > 0 + def hasShuffleRead = shuffleReadTotalBytes > 0 + def hasShuffleWrite = shuffleWriteBytes > 0 + def hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0 } /** diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 12d23a92878cf..199f731b92bcc 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -30,7 +30,10 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { - val rddId = request.getParameter("id").toInt + val parameterId = request.getParameter("id") + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val rddId = parameterId.toInt val storageStatusList = listener.storageStatusList val rddInfo = listener.rddInfoList.find(_.id == rddId).getOrElse { // Rather than crashing, render an "RDD Not Found" page diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 3d9c6192ff7f7..48a6ede05e17b 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -79,8 +79,6 @@ private[spark] object AkkaUtils extends Logging { val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000) - val akkaFailureDetector = - conf.getDouble("spark.akka.failure-detector.threshold", 300.0) val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) val secretKey = securityManager.getSecretKey() @@ -106,7 +104,6 @@ private[spark] object AkkaUtils extends Logging { |akka.remote.secure-cookie = "$secureCookie" |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s - |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector |akka.actor.provider = "akka.remote.RemoteActorRefProvider" |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport" |akka.remote.netty.tcp.hostname = "$host" diff --git a/core/src/main/scala/org/apache/spark/util/Clock.scala b/core/src/main/scala/org/apache/spark/util/Clock.scala index 97c2b45aabf28..e92ed11bd165b 100644 --- a/core/src/main/scala/org/apache/spark/util/Clock.scala +++ b/core/src/main/scala/org/apache/spark/util/Clock.scala @@ -21,9 +21,47 @@ package org.apache.spark.util * An interface to represent clocks, so that they can be mocked out in unit tests. */ private[spark] trait Clock { - def getTime(): Long + def getTimeMillis(): Long + def waitTillTime(targetTime: Long): Long } -private[spark] object SystemClock extends Clock { - def getTime(): Long = System.currentTimeMillis() +/** + * A clock backed by the actual time from the OS as reported by the `System` API. + */ +private[spark] class SystemClock extends Clock { + + val minPollTime = 25L + + /** + * @return the same time (milliseconds since the epoch) + * as is reported by `System.currentTimeMillis()` + */ + def getTimeMillis(): Long = System.currentTimeMillis() + + /** + * @param targetTime block until the current time is at least this value + * @return current system time when wait has completed + */ + def waitTillTime(targetTime: Long): Long = { + var currentTime = 0L + currentTime = System.currentTimeMillis() + + var waitTime = targetTime - currentTime + if (waitTime <= 0) { + return currentTime + } + + val pollTime = math.max(waitTime / 10.0, minPollTime).toLong + + while (true) { + currentTime = System.currentTimeMillis() + waitTime = targetTime - currentTime + if (waitTime <= 0) { + return currentTime + } + val sleepTime = math.min(waitTime, pollTime) + Thread.sleep(sleepTime) + } + -1 + } } diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index b0ed908b84424..e9b2b8d24b476 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -76,9 +76,21 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { def stop(): Unit = { if (stopped.compareAndSet(false, true)) { eventThread.interrupt() - eventThread.join() - // Call onStop after the event thread exits to make sure onReceive happens before onStop - onStop() + var onStopCalled = false + try { + eventThread.join() + // Call onStop after the event thread exits to make sure onReceive happens before onStop + onStopCalled = true + onStop() + } catch { + case ie: InterruptedException => + Thread.currentThread().interrupt() + if (!onStopCalled) { + // ie is thrown from `eventThread.join()`. Otherwise, we should not call `onStop` since + // it's already called. + onStop() + } + } } else { // Keep quiet to allow calling `stop` multiple times. } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 8e0e41ad3782e..474f79fb756f6 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -89,6 +89,8 @@ private[spark] object JsonProtocol { executorAddedToJson(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => executorRemovedToJson(executorRemoved) + case logStart: SparkListenerLogStart => + logStartToJson(logStart) // These aren't used, but keeps compiler happy case SparkListenerExecutorMetricsUpdate(_, _) => JNothing } @@ -214,6 +216,11 @@ private[spark] object JsonProtocol { ("Removed Reason" -> executorRemoved.reason) } + def logStartToJson(logStart: SparkListenerLogStart): JValue = { + ("Event" -> Utils.getFormattedClassName(logStart)) ~ + ("Spark Version" -> SPARK_VERSION) + } + /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | * -------------------------------------------------------------------- */ @@ -293,22 +300,27 @@ private[spark] object JsonProtocol { ("Remote Blocks Fetched" -> shuffleReadMetrics.remoteBlocksFetched) ~ ("Local Blocks Fetched" -> shuffleReadMetrics.localBlocksFetched) ~ ("Fetch Wait Time" -> shuffleReadMetrics.fetchWaitTime) ~ - ("Remote Bytes Read" -> shuffleReadMetrics.remoteBytesRead) + ("Remote Bytes Read" -> shuffleReadMetrics.remoteBytesRead) ~ + ("Local Bytes Read" -> shuffleReadMetrics.localBytesRead) ~ + ("Total Records Read" -> shuffleReadMetrics.recordsRead) } def shuffleWriteMetricsToJson(shuffleWriteMetrics: ShuffleWriteMetrics): JValue = { ("Shuffle Bytes Written" -> shuffleWriteMetrics.shuffleBytesWritten) ~ - ("Shuffle Write Time" -> shuffleWriteMetrics.shuffleWriteTime) + ("Shuffle Write Time" -> shuffleWriteMetrics.shuffleWriteTime) ~ + ("Shuffle Records Written" -> shuffleWriteMetrics.shuffleRecordsWritten) } def inputMetricsToJson(inputMetrics: InputMetrics): JValue = { ("Data Read Method" -> inputMetrics.readMethod.toString) ~ - ("Bytes Read" -> inputMetrics.bytesRead) + ("Bytes Read" -> inputMetrics.bytesRead) ~ + ("Records Read" -> inputMetrics.recordsRead) } def outputMetricsToJson(outputMetrics: OutputMetrics): JValue = { ("Data Write Method" -> outputMetrics.writeMethod.toString) ~ - ("Bytes Written" -> outputMetrics.bytesWritten) + ("Bytes Written" -> outputMetrics.bytesWritten) ~ + ("Records Written" -> outputMetrics.recordsWritten) } def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = { @@ -383,7 +395,8 @@ private[spark] object JsonProtocol { def executorInfoToJson(executorInfo: ExecutorInfo): JValue = { ("Host" -> executorInfo.executorHost) ~ - ("Total Cores" -> executorInfo.totalCores) + ("Total Cores" -> executorInfo.totalCores) ~ + ("Log Urls" -> mapToJson(executorInfo.logUrlMap)) } /** ------------------------------ * @@ -441,6 +454,7 @@ private[spark] object JsonProtocol { val applicationEnd = Utils.getFormattedClassName(SparkListenerApplicationEnd) val executorAdded = Utils.getFormattedClassName(SparkListenerExecutorAdded) val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) + val logStart = Utils.getFormattedClassName(SparkListenerLogStart) (json \ "Event").extract[String] match { case `stageSubmitted` => stageSubmittedFromJson(json) @@ -458,6 +472,7 @@ private[spark] object JsonProtocol { case `applicationEnd` => applicationEndFromJson(json) case `executorAdded` => executorAddedFromJson(json) case `executorRemoved` => executorRemovedFromJson(json) + case `logStart` => logStartFromJson(json) } } @@ -568,6 +583,11 @@ private[spark] object JsonProtocol { SparkListenerExecutorRemoved(time, executorId, reason) } + def logStartFromJson(json: JValue): SparkListenerLogStart = { + val sparkVersion = (json \ "Spark Version").extract[String] + SparkListenerLogStart(sparkVersion) + } + /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | * ---------------------------------------------------------------------- */ @@ -669,6 +689,8 @@ private[spark] object JsonProtocol { metrics.incLocalBlocksFetched((json \ "Local Blocks Fetched").extract[Int]) metrics.incFetchWaitTime((json \ "Fetch Wait Time").extract[Long]) metrics.incRemoteBytesRead((json \ "Remote Bytes Read").extract[Long]) + metrics.incLocalBytesRead((json \ "Local Bytes Read").extractOpt[Long].getOrElse(0)) + metrics.incRecordsRead((json \ "Total Records Read").extractOpt[Long].getOrElse(0)) metrics } @@ -676,13 +698,16 @@ private[spark] object JsonProtocol { val metrics = new ShuffleWriteMetrics metrics.incShuffleBytesWritten((json \ "Shuffle Bytes Written").extract[Long]) metrics.incShuffleWriteTime((json \ "Shuffle Write Time").extract[Long]) + metrics.setShuffleRecordsWritten((json \ "Shuffle Records Written") + .extractOpt[Long].getOrElse(0)) metrics } def inputMetricsFromJson(json: JValue): InputMetrics = { val metrics = new InputMetrics( DataReadMethod.withName((json \ "Data Read Method").extract[String])) - metrics.addBytesRead((json \ "Bytes Read").extract[Long]) + metrics.incBytesRead((json \ "Bytes Read").extract[Long]) + metrics.incRecordsRead((json \ "Records Read").extractOpt[Long].getOrElse(0)) metrics } @@ -690,6 +715,7 @@ private[spark] object JsonProtocol { val metrics = new OutputMetrics( DataWriteMethod.withName((json \ "Data Write Method").extract[String])) metrics.setBytesWritten((json \ "Bytes Written").extract[Long]) + metrics.setRecordsWritten((json \ "Records Written").extractOpt[Long].getOrElse(0)) metrics } @@ -792,7 +818,8 @@ private[spark] object JsonProtocol { def executorInfoFromJson(json: JValue): ExecutorInfo = { val executorHost = (json \ "Host").extract[String] val totalCores = (json \ "Total Cores").extract[Int] - new ExecutorInfo(executorHost, totalCores) + val logUrls = mapFromJson(json \ "Log Urls").toMap + new ExecutorInfo(executorHost, totalCores, logUrls) } /** -------------------------------- * diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index bd0aa4dc4650f..d60b8b9a31a9b 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -28,7 +28,8 @@ import org.apache.spark.Logging */ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { - private val listeners = new CopyOnWriteArrayList[L] + // Marked `private[spark]` for access in tests. + private[spark] val listeners = new CopyOnWriteArrayList[L] /** * Add a listener to listen events. This method is thread-safe and can be called in any thread. diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala new file mode 100644 index 0000000000000..cf89c1782fd67 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * A `Clock` whose time can be manually set and modified. Its reported time does not change + * as time elapses, but only as its time is modified by callers. This is mainly useful for + * testing. + * + * @param time initial time (in milliseconds since the epoch) + */ +private[spark] class ManualClock(private var time: Long) extends Clock { + + /** + * @return `ManualClock` with initial time 0 + */ + def this() = this(0L) + + def getTimeMillis(): Long = + synchronized { + time + } + + /** + * @param timeToSet new time (in milliseconds) that the clock should represent + */ + def setTime(timeToSet: Long) = + synchronized { + time = timeToSet + notifyAll() + } + + /** + * @param timeToAdd time (in milliseconds) to add to the clock's time + */ + def advance(timeToAdd: Long) = + synchronized { + time += timeToAdd + notifyAll() + } + + /** + * @param targetTime block until the clock time is set or advanced to at least this time + * @return current time reported by the clock when waiting finishes + */ + def waitTillTime(targetTime: Long): Long = + synchronized { + while (time < targetTime) { + wait(100) + } + getTimeMillis() + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala new file mode 100644 index 0000000000000..d9c7103b2f3bf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.net.{URLClassLoader, URL} +import java.util.Enumeration +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConversions._ + +import org.apache.spark.util.ParentClassLoader + +/** + * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader. + */ +private[spark] class MutableURLClassLoader(urls: Array[URL], parent: ClassLoader) + extends URLClassLoader(urls, parent) { + + override def addURL(url: URL): Unit = { + super.addURL(url) + } + + override def getURLs(): Array[URL] = { + super.getURLs() + } + +} + +/** + * A mutable class loader that gives preference to its own URLs over the parent class loader + * when loading classes and resources. + */ +private[spark] class ChildFirstURLClassLoader(urls: Array[URL], parent: ClassLoader) + extends MutableURLClassLoader(urls, null) { + + private val parentClassLoader = new ParentClassLoader(parent) + + /** + * Used to implement fine-grained class loading locks similar to what is done by Java 7. This + * prevents deadlock issues when using non-hierarchical class loaders. + * + * Note that due to Java 6 compatibility (and some issues with implementing class loaders in + * Scala), Java 7's `ClassLoader.registerAsParallelCapable` method is not called. + */ + private val locks = new ConcurrentHashMap[String, Object]() + + override def loadClass(name: String, resolve: Boolean): Class[_] = { + var lock = locks.get(name) + if (lock == null) { + val newLock = new Object() + lock = locks.putIfAbsent(name, newLock) + if (lock == null) { + lock = newLock + } + } + + lock.synchronized { + try { + super.loadClass(name, resolve) + } catch { + case e: ClassNotFoundException => + parentClassLoader.loadClass(name, resolve) + } + } + } + + override def getResource(name: String): URL = { + val url = super.findResource(name) + val res = if (url != null) url else parentClassLoader.getResource(name) + res + } + + override def getResources(name: String): Enumeration[URL] = { + val urls = super.findResources(name) + val res = + if (urls != null && urls.hasMoreElements()) { + urls + } else { + parentClassLoader.getResources(name) + } + res + } + + override def addURL(url: URL) { + super.addURL(url) + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala index 3abc12681fe9a..6d8d9e8da3678 100644 --- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala @@ -18,7 +18,7 @@ package org.apache.spark.util /** - * A class loader which makes findClass accesible to the child + * A class loader which makes some protected methods in ClassLoader accesible. */ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) { @@ -29,4 +29,9 @@ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader( override def loadClass(name: String): Class[_] = { super.loadClass(name) } + + override def loadClass(name: String, resolve: Boolean): Class[_] = { + super.loadClass(name, resolve) + } + } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e9f2aed9ffbea..c136584b196a3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -38,6 +38,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} +import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ @@ -212,8 +213,8 @@ private[spark] object Utils extends Logging { // Is the path already registered to be deleted via a shutdown hook ? def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { val absolutePath = file.getPath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.contains(absolutePath) } } @@ -279,16 +280,9 @@ private[spark] object Utils extends Logging { maxAttempts + " attempts!") } try { - dir = new File(root, "spark-" + UUID.randomUUID.toString) + dir = new File(root, namePrefix + "-" + UUID.randomUUID.toString) if (dir.exists() || !dir.mkdirs()) { dir = null - } else { - // Restrict file permissions via chmod if available. - // For Windows this step is ignored. - if (!isWindows && !chmod700(dir)) { - dir.delete() - dir = null - } } } catch { case e: SecurityException => dir = null; } } @@ -386,8 +380,10 @@ private[spark] object Utils extends Logging { } /** - * Download a file to target directory. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * Download a file or directory to target directory. Supports fetching the file in a variety of + * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based + * on the URL parameter. Fetching directories is only supported from Hadoop-compatible + * filesystems. * * If `useCache` is true, first attempts to fetch the file to a local cache that's shared * across executors running the same application. `useCache` is used mainly for @@ -406,7 +402,8 @@ private[spark] object Utils extends Logging { useCache: Boolean) { val fileName = url.split("/").last val targetFile = new File(targetDir, fileName) - if (useCache) { + val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true) + if (useCache && fetchCacheEnabled) { val cachedFileName = s"${url.hashCode}${timestamp}_cache" val lockFileName = s"${url.hashCode}${timestamp}_lock" val localDir = new File(getLocalDir(conf)) @@ -456,7 +453,6 @@ private[spark] object Utils extends Logging { * * @param url URL that `sourceFile` originated from, for logging purposes. * @param in InputStream to download. - * @param tempFile File path to download `in` to. * @param destFile File path to move `tempFile` to. * @param fileOverwrite Whether to delete/overwrite an existing `destFile` that does not match * `sourceFile` @@ -464,9 +460,11 @@ private[spark] object Utils extends Logging { private def downloadFile( url: String, in: InputStream, - tempFile: File, destFile: File, fileOverwrite: Boolean): Unit = { + val tempFile = File.createTempFile("fetchFileTemp", null, + new File(destFile.getParentFile.getAbsolutePath)) + logInfo(s"Fetching $url to $tempFile") try { val out = new FileOutputStream(tempFile) @@ -505,7 +503,7 @@ private[spark] object Utils extends Logging { removeSourceFile: Boolean = false): Unit = { if (destFile.exists) { - if (!Files.equal(sourceFile, destFile)) { + if (!filesEqualRecursive(sourceFile, destFile)) { if (fileOverwrite) { logInfo( s"File $destFile exists and does not match contents of $url, replacing it with $url" @@ -540,13 +538,44 @@ private[spark] object Utils extends Logging { Files.move(sourceFile, destFile) } else { logInfo(s"Copying ${sourceFile.getAbsolutePath} to ${destFile.getAbsolutePath}") - Files.copy(sourceFile, destFile) + copyRecursive(sourceFile, destFile) + } + } + + private def filesEqualRecursive(file1: File, file2: File): Boolean = { + if (file1.isDirectory && file2.isDirectory) { + val subfiles1 = file1.listFiles() + val subfiles2 = file2.listFiles() + if (subfiles1.size != subfiles2.size) { + return false + } + subfiles1.sortBy(_.getName).zip(subfiles2.sortBy(_.getName)).forall { + case (f1, f2) => filesEqualRecursive(f1, f2) + } + } else if (file1.isFile && file2.isFile) { + Files.equal(file1, file2) + } else { + false + } + } + + private def copyRecursive(source: File, dest: File): Unit = { + if (source.isDirectory) { + if (!dest.mkdir()) { + throw new IOException(s"Failed to create directory ${dest.getPath}") + } + val subfiles = source.listFiles() + subfiles.foreach(f => copyRecursive(f, new File(dest, f.getName))) + } else { + Files.copy(source, dest) } } /** - * Download a file to target directory. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * Download a file or directory to target directory. Supports fetching the file in a variety of + * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based + * on the URL parameter. Fetching directories is only supported from Hadoop-compatible + * filesystems. * * Throws SparkException if the target file already exists and has different contents than * the requested file. @@ -558,14 +587,11 @@ private[spark] object Utils extends Logging { conf: SparkConf, securityMgr: SecurityManager, hadoopConf: Configuration) { - val tempFile = File.createTempFile("fetchFileTemp", null, new File(targetDir.getAbsolutePath)) val targetFile = new File(targetDir, filename) val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { case "http" | "https" | "ftp" => - logInfo("Fetching " + url + " to " + tempFile) - var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { logDebug("fetchFile with security enabled") @@ -583,17 +609,48 @@ private[spark] object Utils extends Logging { uc.setReadTimeout(timeout) uc.connect() val in = uc.getInputStream() - downloadFile(url, in, tempFile, targetFile, fileOverwrite) + downloadFile(url, in, targetFile, fileOverwrite) case "file" => // In the case of a local file, copy the local file to the target directory. // Note the difference between uri vs url. val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) copyFile(url, sourceFile, targetFile, fileOverwrite) case _ => - // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others val fs = getHadoopFileSystem(uri, hadoopConf) - val in = fs.open(new Path(uri)) - downloadFile(url, in, tempFile, targetFile, fileOverwrite) + val path = new Path(uri) + fetchHcfsFile(path, targetDir, fs, conf, hadoopConf, fileOverwrite, + filename = Some(filename)) + } + } + + /** + * Fetch a file or directory from a Hadoop-compatible filesystem. + * + * Visible for testing + */ + private[spark] def fetchHcfsFile( + path: Path, + targetDir: File, + fs: FileSystem, + conf: SparkConf, + hadoopConf: Configuration, + fileOverwrite: Boolean, + filename: Option[String] = None): Unit = { + if (!targetDir.exists() && !targetDir.mkdir()) { + throw new IOException(s"Failed to create directory ${targetDir.getPath}") + } + val dest = new File(targetDir, filename.getOrElse(path.getName)) + if (fs.isFile(path)) { + val in = fs.open(path) + try { + downloadFile(path.toString, in, dest, fileOverwrite) + } finally { + in.close() + } + } else { + fs.listStatus(path).foreach { fileStatus => + fetchHcfsFile(fileStatus.getPath(), dest, fs, conf, hadoopConf, fileOverwrite) + } } } @@ -644,7 +701,9 @@ private[spark] object Utils extends Logging { try { val rootDir = new File(root) if (rootDir.exists || rootDir.mkdirs()) { - Some(createDirectory(root).getAbsolutePath()) + val dir = createDirectory(root) + chmod700(dir) + Some(dir.getAbsolutePath) } else { logError(s"Failed to create dir in $root. Ignoring this directory.") None @@ -1104,9 +1163,9 @@ private[spark] object Utils extends Logging { // finding the call site of a method. val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r - val SCALA_CLASS_REGEX = """^scala""".r + val SCALA_CORE_CLASS_PREFIX = "scala" val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined - val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined + val isScalaClass = className.startsWith(SCALA_CORE_CLASS_PREFIX) // If the class is a Spark internal class or a Scala class, then exclude. isSparkCoreClass || isScalaClass } @@ -1760,6 +1819,10 @@ private[spark] object Utils extends Logging { startService: Int => (T, Int), conf: SparkConf, serviceName: String = ""): (T, Int) = { + + require(startPort == 0 || (1024 <= startPort && startPort < 65536), + "startPort should be between 1024 and 65535 (inclusive), or 0 for a random free port.") + val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" val maxRetries = portMaxRetries(conf) for (offset <- 0 to maxRetries) { @@ -1928,6 +1991,16 @@ private[spark] object Utils extends Logging { throw new SparkException("Invalid master URL: " + sparkUrl, e) } } + + /** + * Returns the current user name. This is the currently logged in user, unless that's been + * overridden by the `SPARK_USER` environment variable. + */ + def getCurrentUserName(): String = { + Option(System.getenv("SPARK_USER")) + .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) + } + } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 6ba03841f746b..244d4af60c4cb 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -352,6 +352,7 @@ private[spark] class ExternalSorter[K, V, C]( // Create our file writers if we haven't done so yet if (partitionWriters == null) { curWriteMetrics = new ShuffleWriteMetrics() + val openStartTime = System.nanoTime partitionWriters = Array.fill(numPartitions) { // Because these files may be read during shuffle, their compression must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use @@ -359,6 +360,10 @@ private[spark] class ExternalSorter[K, V, C]( val (blockId, file) = diskBlockManager.createTempShuffleBlock() blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open() } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) } // No need to sort stuff, just write each element out @@ -723,6 +728,7 @@ private[spark] class ExternalSorter[K, V, C]( partitionWriters.foreach(_.commitAndClose()) var out: FileOutputStream = null var in: FileInputStream = null + val writeStartTime = System.nanoTime try { out = new FileOutputStream(outputFile, true) for (i <- 0 until numPartitions) { @@ -739,6 +745,8 @@ private[spark] class ExternalSorter[K, V, C]( if (in != null) { in.close() } + context.taskMetrics.shuffleWriteMetrics.foreach( + _.incShuffleWriteTime(System.nanoTime - writeStartTime)) } } else { // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by @@ -763,6 +771,7 @@ private[spark] class ExternalSorter[K, V, C]( if (curWriteMetrics != null) { m.incShuffleBytesWritten(curWriteMetrics.shuffleBytesWritten) m.incShuffleWriteTime(curWriteMetrics.shuffleWriteTime) + m.incShuffleRecordsWritten(curWriteMetrics.shuffleRecordsWritten) } } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index b16a1e9460286..dc9f77360d679 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -267,6 +267,22 @@ public void call(String s) throws IOException { Assert.assertEquals(2, accum.value().intValue()); } + @Test + public void foreachPartition() { + final Accumulator accum = sc.accumulator(0); + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); + rdd.foreachPartition(new VoidFunction>() { + @Override + public void call(Iterator iter) throws IOException { + while (iter.hasNext()) { + iter.next(); + accum.add(1); + } + } + }); + Assert.assertEquals(2, accum.value().intValue()); + } + @Test public void toLocalIterator() { List correct = Arrays.asList(1, 2, 3, 4); @@ -657,6 +673,13 @@ public Boolean call(Integer i) { }).isEmpty()); } + @Test + public void toArray() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3)); + List list = rdd.toArray(); + Assert.assertEquals(Arrays.asList(1, 2, 3), list); + } + @Test public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); @@ -710,6 +733,80 @@ public void javaDoubleRDDHistoGram() { Assert.assertArrayEquals(expected_counts, histogram); } + private static class DoubleComparator implements Comparator, Serializable { + public int compare(Double o1, Double o2) { + return o1.compareTo(o2); + } + } + + @Test + public void max() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.max(new DoubleComparator()); + Assert.assertEquals(4.0, max, 0.001); + } + + @Test + public void min() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.min(new DoubleComparator()); + Assert.assertEquals(1.0, max, 0.001); + } + + @Test + public void takeOrdered() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2, new DoubleComparator())); + Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2)); + } + + @Test + public void top() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + List top2 = rdd.top(2); + Assert.assertEquals(Arrays.asList(4, 3), top2); + } + + private static class AddInts implements Function2 { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + } + + @Test + public void reduce() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.reduce(new AddInts()); + Assert.assertEquals(10, sum); + } + + @Test + public void reduceOnJavaDoubleRDD() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double sum = rdd.reduce(new Function2() { + @Override + public Double call(Double v1, Double v2) throws Exception { + return v1 + v2; + } + }); + Assert.assertEquals(10.0, sum, 0.001); + } + + @Test + public void fold() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.fold(0, new AddInts()); + Assert.assertEquals(10, sum); + } + + @Test + public void aggregate() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.aggregate(0, new AddInts(), new AddInts()); + Assert.assertEquals(10, sum); + } + @Test public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); @@ -826,6 +923,25 @@ public Iterable call(Iterator iter) { Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); } + + @Test + public void mapPartitionsWithIndex() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); + JavaRDD partitionSums = rdd.mapPartitionsWithIndex( + new Function2, Iterator>() { + @Override + public Iterator call(Integer index, Iterator iter) throws Exception { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); + } + return Collections.singletonList(sum).iterator(); + } + }, false); + Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); + } + + @Test public void repartition() { // Shrinking number of partitions @@ -1512,6 +1628,19 @@ public void collectAsync() throws Exception { Assert.assertEquals(1, future.jobIds().size()); } + @Test + public void takeAsync() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction> future = rdd.takeAsync(1); + List result = future.get(); + Assert.assertEquals(1, result.size()); + Assert.assertEquals((Integer) 1, result.get(0)); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(future.isDone()); + Assert.assertEquals(1, future.jobIds().size()); + } + @Test public void foreachAsync() throws Exception { List data = Arrays.asList(1, 2, 3, 4, 5); diff --git a/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java b/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java new file mode 100644 index 0000000000000..45772b6d3c20d --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java @@ -0,0 +1,134 @@ +/** + * Copyright 2015 Stijn de Gouw + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package org.apache.spark.util.collection; + +import java.util.*; + +/** + * This codes generates a int array which fails the standard TimSort. + * + * The blog that reported the bug + * http://www.envisage-project.eu/timsort-specification-and-verification/ + * + * This codes was originally wrote by Stijn de Gouw, modified by Evan Yu to adapt to + * our test suite. + * + * https://github.com/abstools/java-timsort-bug + * https://github.com/abstools/java-timsort-bug/blob/master/LICENSE + */ +public class TestTimSort { + + private static final int MIN_MERGE = 32; + + /** + * Returns an array of integers that demonstrate the bug in TimSort + */ + public static int[] getTimSortBugTestSet(int length) { + int minRun = minRunLength(length); + List runs = runsJDKWorstCase(minRun, length); + return createArray(runs, length); + } + + private static int minRunLength(int n) { + int r = 0; // Becomes 1 if any 1 bits are shifted off + while (n >= MIN_MERGE) { + r |= (n & 1); + n >>= 1; + } + return n + r; + } + + private static int[] createArray(List runs, int length) { + int[] a = new int[length]; + Arrays.fill(a, 0); + int endRun = -1; + for (long len : runs) { + a[endRun += len] = 1; + } + a[length - 1] = 0; + return a; + } + + /** + * Fills runs with a sequence of run lengths of the form
    + * Y_n x_{n,1} x_{n,2} ... x_{n,l_n}
    + * Y_{n-1} x_{n-1,1} x_{n-1,2} ... x_{n-1,l_{n-1}}
    + * ...
    + * Y_1 x_{1,1} x_{1,2} ... x_{1,l_1}
    + * The Y_i's are chosen to satisfy the invariant throughout execution, + * but the x_{i,j}'s are merged (by TimSort.mergeCollapse) + * into an X_i that violates the invariant. + * + * @param length The sum of all run lengths that will be added to runs. + */ + private static List runsJDKWorstCase(int minRun, int length) { + List runs = new ArrayList(); + + long runningTotal = 0, Y = minRun + 4, X = minRun; + + while (runningTotal + Y + X <= length) { + runningTotal += X + Y; + generateJDKWrongElem(runs, minRun, X); + runs.add(0, Y); + // X_{i+1} = Y_i + x_{i,1} + 1, since runs.get(1) = x_{i,1} + X = Y + runs.get(1) + 1; + // Y_{i+1} = X_{i+1} + Y_i + 1 + Y += X + 1; + } + + if (runningTotal + X <= length) { + runningTotal += X; + generateJDKWrongElem(runs, minRun, X); + } + + runs.add(length - runningTotal); + return runs; + } + + /** + * Adds a sequence x_1, ..., x_n of run lengths to runs such that:
    + * 1. X = x_1 + ... + x_n
    + * 2. x_j >= minRun for all j
    + * 3. x_1 + ... + x_{j-2} < x_j < x_1 + ... + x_{j-1} for all j
    + * These conditions guarantee that TimSort merges all x_j's one by one + * (resulting in X) using only merges on the second-to-last element. + * + * @param X The sum of the sequence that should be added to runs. + */ + private static void generateJDKWrongElem(List runs, int minRun, long X) { + for (long newTotal; X >= 2 * minRun + 1; X = newTotal) { + //Default strategy + newTotal = X / 2 + 1; + //Specialized strategies + if (3 * minRun + 3 <= X && X <= 4 * minRun + 1) { + // add x_1=MIN+1, x_2=MIN, x_3=X-newTotal to runs + newTotal = 2 * minRun + 1; + } else if (5 * minRun + 5 <= X && X <= 6 * minRun + 5) { + // add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=X-newTotal to runs + newTotal = 3 * minRun + 3; + } else if (8 * minRun + 9 <= X && X <= 10 * minRun + 9) { + // add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=2MIN+2, x_5=X-newTotal to runs + newTotal = 5 * minRun + 5; + } else if (13 * minRun + 15 <= X && X <= 16 * minRun + 17) { + // add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=2MIN+2, x_5=3MIN+4, x_6=X-newTotal to runs + newTotal = 8 * minRun + 9; + } + runs.add(0, X - newTotal); + } + runs.add(0, X); + } +} diff --git a/core/src/test/resources/test_metrics_system.properties b/core/src/test/resources/test_metrics_system.properties index 35d0bd3b8d0b8..4e8b8465696e5 100644 --- a/core/src/test/resources/test_metrics_system.properties +++ b/core/src/test/resources/test_metrics_system.properties @@ -18,7 +18,5 @@ *.sink.console.period = 10 *.sink.console.unit = seconds test.sink.console.class = org.apache.spark.metrics.sink.ConsoleSink -test.sink.dummy.class = org.apache.spark.metrics.sink.DummySink -test.source.dummy.class = org.apache.spark.metrics.source.DummySource test.sink.console.period = 20 test.sink.console.unit = minutes diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index d7d9dc7b50f30..4b25c200a695a 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -17,16 +17,18 @@ package org.apache.spark +import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, FunSuite} -import org.scalatest.mock.EasyMockSugar +import org.scalatest.mock.MockitoSugar -import org.apache.spark.executor.{DataReadMethod, TaskMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.RDD import org.apache.spark.storage._ // TODO: Test the CacheManager's thread-safety aspects -class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar { - var sc : SparkContext = _ +class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter + with MockitoSugar { + var blockManager: BlockManager = _ var cacheManager: CacheManager = _ var split: Partition = _ @@ -57,10 +59,6 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar }.cache() } - after { - sc.stop() - } - test("get uncached rdd") { // Do not mock this test, because attempting to match Array[Any], which is not covariant, // in blockManager.put is a losing battle. You have been warned. @@ -75,29 +73,21 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } test("get cached rdd") { - expecting { - val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) - blockManager.get(RDDBlockId(0, 0)).andReturn(Some(result)) - } + val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) + when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0, 0) - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(5, 6, 7)) - } + val context = new TaskContextImpl(0, 0, 0, 0) + val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) + assert(value.toList === List(5, 6, 7)) } test("get uncached local rdd") { - expecting { - // Local computation should not persist the resulting value, so don't expect a put(). - blockManager.get(RDDBlockId(0, 0)).andReturn(None) - } + // Local computation should not persist the resulting value, so don't expect a put(). + when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0, 0, true) - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(1, 2, 3, 4)) - } + val context = new TaskContextImpl(0, 0, 0, 0, true) + val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) + assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 57081ddd959a5..3ded1e4af8742 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,18 +19,29 @@ package org.apache.spark import scala.collection.mutable -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.util.ManualClock /** * Test add and remove behavior of ExecutorAllocationManager. */ -class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { +class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter { import ExecutorAllocationManager._ import ExecutorAllocationManagerSuite._ + private val contexts = new mutable.ListBuffer[SparkContext]() + + before { + contexts.clear() + } + + after { + contexts.foreach(_.stop()) + } + test("verify min/max executors") { val conf = new SparkConf() .setMaster("local") @@ -38,18 +49,19 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.testing", "true") val sc0 = new SparkContext(conf) + contexts += sc0 assert(sc0.executorAllocationManager.isDefined) sc0.stop() // Min < 0 val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "-1") - intercept[SparkException] { new SparkContext(conf1) } + intercept[SparkException] { contexts += new SparkContext(conf1) } SparkEnv.get.stop() SparkContext.clearActiveContext() // Max < 0 val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "-1") - intercept[SparkException] { new SparkContext(conf2) } + intercept[SparkException] { contexts += new SparkContext(conf2) } SparkEnv.get.stop() SparkContext.clearActiveContext() @@ -144,8 +156,8 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { // Verify that running a task reduces the cap sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) assert(numExecutorsPending(manager) === 4) assert(addExecutors(manager) === 1) @@ -175,6 +187,33 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(numExecutorsPending(manager) === 9) } + test("cancel pending executors when no longer needed") { + sc = createSparkContext(1, 10) + val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5))) + + assert(numExecutorsPending(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 1) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager) === 3) + + val task1Info = createTaskInfo(0, 0, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task1Info)) + + assert(numExecutorsToAdd(manager) === 4) + assert(addExecutors(manager) === 2) + + val task2Info = createTaskInfo(1, 0, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task1Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task2Info, null)) + + assert(adjustRequestedExecutors(manager) === -1) + } + test("remove executors") { sc = createSparkContext(5, 10) val manager = sc.executorAllocationManager.get @@ -270,15 +309,15 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeExecutor(manager, "5")) assert(removeExecutor(manager, "6")) assert(executorIds(manager).size === 10) - assert(addExecutors(manager) === 0) // still at upper limit + assert(addExecutors(manager) === 1) onExecutorRemoved(manager, "3") onExecutorRemoved(manager, "4") assert(executorIds(manager).size === 8) // Add succeeds again, now that we are no longer at the upper limit // Number of executors added restarts at 1 - assert(addExecutors(manager) === 1) - assert(addExecutors(manager) === 1) // upper limit reached again + assert(addExecutors(manager) === 2) + assert(addExecutors(manager) === 1) // upper limit reached assert(addExecutors(manager) === 0) assert(executorIds(manager).size === 8) onExecutorRemoved(manager, "5") @@ -286,9 +325,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { onExecutorAdded(manager, "13") onExecutorAdded(manager, "14") assert(executorIds(manager).size === 8) - assert(addExecutors(manager) === 1) - assert(addExecutors(manager) === 1) // upper limit reached again - assert(addExecutors(manager) === 0) + assert(addExecutors(manager) === 0) // still at upper limit onExecutorAdded(manager, "15") onExecutorAdded(manager, "16") assert(executorIds(manager).size === 10) @@ -296,7 +333,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("starting/canceling add timer") { sc = createSparkContext(2, 10) - val clock = new TestClock(8888L) + val clock = new ManualClock(8888L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -305,21 +342,21 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { onSchedulerBacklogged(manager) val firstAddTime = addTime(manager) assert(firstAddTime === clock.getTimeMillis + schedulerBacklogTimeout * 1000) - clock.tick(100L) + clock.advance(100L) onSchedulerBacklogged(manager) assert(addTime(manager) === firstAddTime) // timer is already started - clock.tick(200L) + clock.advance(200L) onSchedulerBacklogged(manager) assert(addTime(manager) === firstAddTime) onSchedulerQueueEmpty(manager) // Restart add timer - clock.tick(1000L) + clock.advance(1000L) assert(addTime(manager) === NOT_SET) onSchedulerBacklogged(manager) val secondAddTime = addTime(manager) assert(secondAddTime === clock.getTimeMillis + schedulerBacklogTimeout * 1000) - clock.tick(100L) + clock.advance(100L) onSchedulerBacklogged(manager) assert(addTime(manager) === secondAddTime) // timer is already started assert(addTime(manager) !== firstAddTime) @@ -328,7 +365,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("starting/canceling remove timers") { sc = createSparkContext(2, 10) - val clock = new TestClock(14444L) + val clock = new ManualClock(14444L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -341,17 +378,17 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("1")) val firstRemoveTime = removeTimes(manager)("1") assert(firstRemoveTime === clock.getTimeMillis + executorIdleTimeout * 1000) - clock.tick(100L) + clock.advance(100L) onExecutorIdle(manager, "1") assert(removeTimes(manager)("1") === firstRemoveTime) // timer is already started - clock.tick(200L) + clock.advance(200L) onExecutorIdle(manager, "1") assert(removeTimes(manager)("1") === firstRemoveTime) - clock.tick(300L) + clock.advance(300L) onExecutorIdle(manager, "2") assert(removeTimes(manager)("2") !== firstRemoveTime) // different executor assert(removeTimes(manager)("2") === clock.getTimeMillis + executorIdleTimeout * 1000) - clock.tick(400L) + clock.advance(400L) onExecutorIdle(manager, "3") assert(removeTimes(manager)("3") !== firstRemoveTime) assert(removeTimes(manager)("3") === clock.getTimeMillis + executorIdleTimeout * 1000) @@ -360,7 +397,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("3")) // Restart remove timer - clock.tick(1000L) + clock.advance(1000L) onExecutorBusy(manager, "1") assert(removeTimes(manager).size === 2) onExecutorIdle(manager, "1") @@ -376,7 +413,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("mock polling loop with no events") { sc = createSparkContext(1, 20) val manager = sc.executorAllocationManager.get - val clock = new TestClock(2020L) + val clock = new ManualClock(2020L) manager.setClock(clock) // No events - we should not be adding or removing @@ -385,15 +422,15 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(100L) + clock.advance(100L) schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(1000L) + clock.advance(1000L) schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(10000L) + clock.advance(10000L) schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) @@ -401,57 +438,57 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("mock polling loop add behavior") { sc = createSparkContext(1, 20) - val clock = new TestClock(2020L) + val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Scheduler queue backlogged onSchedulerBacklogged(manager) - clock.tick(schedulerBacklogTimeout * 1000 / 2) + clock.advance(schedulerBacklogTimeout * 1000 / 2) schedule(manager) assert(numExecutorsPending(manager) === 0) // timer not exceeded yet - clock.tick(schedulerBacklogTimeout * 1000) + clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 1) // first timer exceeded - clock.tick(sustainedSchedulerBacklogTimeout * 1000 / 2) + clock.advance(sustainedSchedulerBacklogTimeout * 1000 / 2) schedule(manager) assert(numExecutorsPending(manager) === 1) // second timer not exceeded yet - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 1 + 2) // second timer exceeded - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 1 + 2 + 4) // third timer exceeded // Scheduler queue drained onSchedulerQueueEmpty(manager) - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7) // timer is canceled - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7) // Scheduler queue backlogged again onSchedulerBacklogged(manager) - clock.tick(schedulerBacklogTimeout * 1000) + clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7 + 1) // timer restarted - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7 + 1 + 2) - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7 + 1 + 2 + 4) - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 20) // limit reached } test("mock polling loop remove behavior") { sc = createSparkContext(1, 20) - val clock = new TestClock(2020L) + val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -461,11 +498,11 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { onExecutorAdded(manager, "executor-3") assert(removeTimes(manager).size === 3) assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(executorIdleTimeout * 1000 / 2) + clock.advance(executorIdleTimeout * 1000 / 2) schedule(manager) assert(removeTimes(manager).size === 3) // idle threshold not reached yet assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(executorIdleTimeout * 1000) + clock.advance(executorIdleTimeout * 1000) schedule(manager) assert(removeTimes(manager).isEmpty) // idle threshold exceeded assert(executorsPendingToRemove(manager).size === 2) // limit reached (1 executor remaining) @@ -486,7 +523,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(!removeTimes(manager).contains("executor-5")) assert(!removeTimes(manager).contains("executor-6")) assert(executorsPendingToRemove(manager).size === 2) - clock.tick(executorIdleTimeout * 1000) + clock.advance(executorIdleTimeout * 1000) schedule(manager) assert(removeTimes(manager).isEmpty) // idle executors are removed assert(executorsPendingToRemove(manager).size === 4) @@ -504,7 +541,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("executor-5")) assert(removeTimes(manager).contains("executor-6")) assert(executorsPendingToRemove(manager).size === 4) - clock.tick(executorIdleTimeout * 1000) + clock.advance(executorIdleTimeout * 1000) schedule(manager) assert(removeTimes(manager).isEmpty) assert(executorsPendingToRemove(manager).size === 6) // limit reached (1 executor remaining) @@ -578,30 +615,28 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).isEmpty) // New executors have registered - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(removeTimes(manager).contains("executor-1")) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-2", "host2", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-2", new ExecutorInfo("host2", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) assert(removeTimes(manager).size === 2) assert(removeTimes(manager).contains("executor-2")) // Existing executors have disconnected - sc.listenerBus.postToAll(SparkListenerBlockManagerRemoved( - 0L, BlockManagerId("executor-1", "host1", 1))) + sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-1", "")) assert(executorIds(manager).size === 1) assert(!executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(!removeTimes(manager).contains("executor-1")) // Unknown executor has disconnected - sc.listenerBus.postToAll(SparkListenerBlockManagerRemoved( - 0L, BlockManagerId("executor-3", "host3", 1))) + sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-3", "")) assert(executorIds(manager).size === 1) assert(removeTimes(manager).size === 1) } @@ -613,8 +648,8 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).isEmpty) sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 0) @@ -625,32 +660,22 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 0) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-2", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-2", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) assert(removeTimes(manager).size === 1) assert(removeTimes(manager).contains("executor-2")) assert(!removeTimes(manager).contains("executor-1")) } -} - -/** - * Helper methods for testing ExecutorAllocationManager. - * This includes methods to access private methods and fields in ExecutorAllocationManager. - */ -private object ExecutorAllocationManagerSuite extends PrivateMethodTester { - private val schedulerBacklogTimeout = 1L - private val sustainedSchedulerBacklogTimeout = 2L - private val executorIdleTimeout = 3L private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { val conf = new SparkConf() @@ -664,9 +689,22 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { sustainedSchedulerBacklogTimeout.toString) .set("spark.dynamicAllocation.executorIdleTimeout", executorIdleTimeout.toString) .set("spark.dynamicAllocation.testing", "true") - new SparkContext(conf) + val sc = new SparkContext(conf) + contexts += sc + sc } +} + +/** + * Helper methods for testing ExecutorAllocationManager. + * This includes methods to access private methods and fields in ExecutorAllocationManager. + */ +private object ExecutorAllocationManagerSuite extends PrivateMethodTester { + private val schedulerBacklogTimeout = 1L + private val sustainedSchedulerBacklogTimeout = 2L + private val executorIdleTimeout = 3L + private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = { new StageInfo(stageId, 0, "name", numTasks, Seq.empty, "no details") } @@ -681,6 +719,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _numExecutorsToAdd = PrivateMethod[Int]('numExecutorsToAdd) private val _numExecutorsPending = PrivateMethod[Int]('numExecutorsPending) + private val _maxNumExecutorsNeeded = PrivateMethod[Int]('maxNumExecutorsNeeded) private val _executorsPendingToRemove = PrivateMethod[collection.Set[String]]('executorsPendingToRemove) private val _executorIds = PrivateMethod[collection.Set[String]]('executorIds) @@ -688,6 +727,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _removeTimes = PrivateMethod[collection.Map[String, Long]]('removeTimes) private val _schedule = PrivateMethod[Unit]('schedule) private val _addExecutors = PrivateMethod[Int]('addExecutors) + private val _addOrCancelExecutorRequests = PrivateMethod[Int]('addOrCancelExecutorRequests) private val _removeExecutor = PrivateMethod[Boolean]('removeExecutor) private val _onExecutorAdded = PrivateMethod[Unit]('onExecutorAdded) private val _onExecutorRemoved = PrivateMethod[Unit]('onExecutorRemoved) @@ -726,7 +766,12 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { } private def addExecutors(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _addExecutors() + val maxNumExecutorsNeeded = manager invokePrivate _maxNumExecutorsNeeded() + manager invokePrivate _addExecutors(maxNumExecutorsNeeded) + } + + private def adjustRequestedExecutors(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _addOrCancelExecutorRequests(0L) } private def removeExecutor(manager: ExecutorAllocationManager, id: String): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 5e24196101fbc..7acd27c735727 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInp import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.FunSuite -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD} import org.apache.spark.util.Utils diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index f57921b768310..30b6184c77839 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -242,14 +242,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex shuffleSpillCompress <- Set(true, false); shuffleCompress <- Set(true, false) ) { - val conf = new SparkConf() + val myConf = conf.clone() .setAppName("test") .setMaster("local") .set("spark.shuffle.spill.compress", shuffleSpillCompress.toString) .set("spark.shuffle.compress", shuffleCompress.toString) .set("spark.shuffle.memoryFraction", "0.001") resetSparkContext() - sc = new SparkContext(conf) + sc = new SparkContext(myConf) try { sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect() } catch { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 8b3c6871a7b39..26d10195d47b8 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -17,9 +17,19 @@ package org.apache.spark +import java.io.File +import java.util.concurrent.TimeUnit + +import com.google.common.base.Charsets._ +import com.google.common.io.Files + import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable +import org.apache.spark.util.Utils + +import scala.concurrent.Await +import scala.concurrent.duration.Duration class SparkContextSuite extends FunSuite with LocalSparkContext { @@ -72,4 +82,112 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { val byteArray2 = converter.convert(bytesWritable) assert(byteArray2.length === 0) } + + test("addFile works") { + val file1 = File.createTempFile("someprefix1", "somesuffix1") + val absolutePath1 = file1.getAbsolutePath + + val pluto = Utils.createTempDir() + val file2 = File.createTempFile("someprefix2", "somesuffix2", pluto) + val relativePath = file2.getParent + "/../" + file2.getParentFile.getName + "/" + file2.getName + val absolutePath2 = file2.getAbsolutePath + + try { + Files.write("somewords1", file1, UTF_8) + Files.write("somewords2", file2, UTF_8) + val length1 = file1.length() + val length2 = file2.length() + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addFile(file1.getAbsolutePath) + sc.addFile(relativePath) + sc.parallelize(Array(1), 1).map(x => { + val gotten1 = new File(SparkFiles.get(file1.getName)) + val gotten2 = new File(SparkFiles.get(file2.getName)) + if (!gotten1.exists()) { + throw new SparkException("file doesn't exist : " + absolutePath1) + } + if (!gotten2.exists()) { + throw new SparkException("file doesn't exist : " + absolutePath2) + } + + if (length1 != gotten1.length()) { + throw new SparkException( + s"file has different length $length1 than added file ${gotten1.length()} : " + absolutePath1) + } + if (length2 != gotten2.length()) { + throw new SparkException( + s"file has different length $length2 than added file ${gotten2.length()} : " + absolutePath2) + } + + if (absolutePath1 == gotten1.getAbsolutePath) { + throw new SparkException("file should have been copied :" + absolutePath1) + } + if (absolutePath2 == gotten2.getAbsolutePath) { + throw new SparkException("file should have been copied : " + absolutePath2) + } + x + }).count() + } finally { + sc.stop() + } + } + + test("addFile recursive works") { + val pluto = Utils.createTempDir() + val neptune = Utils.createTempDir(pluto.getAbsolutePath) + val saturn = Utils.createTempDir(neptune.getAbsolutePath) + val alien1 = File.createTempFile("alien", "1", neptune) + val alien2 = File.createTempFile("alien", "2", saturn) + + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addFile(neptune.getAbsolutePath, true) + sc.parallelize(Array(1), 1).map(x => { + val sep = File.separator + if (!new File(SparkFiles.get(neptune.getName + sep + alien1.getName)).exists()) { + throw new SparkException("can't access file under root added directory") + } + if (!new File(SparkFiles.get(neptune.getName + sep + saturn.getName + sep + alien2.getName)) + .exists()) { + throw new SparkException("can't access file in nested directory") + } + if (new File(SparkFiles.get(pluto.getName + sep + neptune.getName + sep + alien1.getName)) + .exists()) { + throw new SparkException("file exists that shouldn't") + } + x + }).count() + } finally { + sc.stop() + } + } + + test("addFile recursive can't add directories by default") { + val dir = Utils.createTempDir() + + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + intercept[SparkException] { + sc.addFile(dir.getAbsolutePath) + } + } finally { + sc.stop() + } + } + + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val future = sc.parallelize(Seq(0)).foreachAsync(_ => {Thread.sleep(1000L)}) + sc.cancelJobGroup("nonExistGroupId") + Await.ready(future, Duration(2, TimeUnit.SECONDS)) + + // In SPARK-6414, sc.cancelJobGroup will cause NullPointerException and cause + // SparkContext to shutdown, so the following assertion will fail. + assert(sc.parallelize(1 to 10).count() == 10L) + } finally { + sc.stop() + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index aa65f7e8915e6..68b5776fc6515 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -68,7 +68,8 @@ class JsonProtocolSuite extends FunSuite { val completedApps = Array[ApplicationInfo]() val activeDrivers = Array(createDriverInfo()) val completedDrivers = Array(createDriverInfo()) - val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps, + val stateResponse = new MasterStateResponse( + "host", 8080, None, workers, activeApps, completedApps, activeDrivers, completedDrivers, RecoveryState.ALIVE) val output = JsonProtocol.writeMasterState(stateResponse) assertValidJson(output) @@ -117,8 +118,8 @@ class JsonProtocolSuite extends FunSuite { } def createExecutorRunner(): ExecutorRunner = { - new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", - new File("sparkHome"), new File("workDir"), "akka://worker", + new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", 123, + "publicAddress", new File("sparkHome"), new File("workDir"), "akka://worker", new SparkConf, Seq("localDir"), ExecutorState.RUNNING) } diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala new file mode 100644 index 0000000000000..54dd7c9c45c61 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.net.URL + +import scala.collection.mutable +import scala.io.Source + +import org.scalatest.FunSuite + +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} +import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext} + +class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { + + /** Length of time to wait while draining listener events. */ + private val WAIT_TIMEOUT_MILLIS = 10000 + + test("verify that correct log urls get propagated from workers") { + sc = new SparkContext("local-cluster[2,1,512]", "test") + + val listener = new SaveExecutorInfo + sc.addSparkListener(listener) + + // Trigger a job so that executors get added + sc.parallelize(1 to 100, 4).map(_.toString).count() + + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + listener.addedExecutorInfos.values.foreach { info => + assert(info.logUrlMap.nonEmpty) + // Browse to each URL to check that it's valid + info.logUrlMap.foreach { case (logType, logUrl) => + val html = Source.fromURL(logUrl).mkString + assert(html.contains(s"$logType log page")) + } + } + } + + test("verify that log urls reflect SPARK_PUBLIC_DNS (SPARK-6175)") { + val SPARK_PUBLIC_DNS = "public_dns" + class MySparkConf extends SparkConf(false) { + override def getenv(name: String) = { + if (name == "SPARK_PUBLIC_DNS") SPARK_PUBLIC_DNS + else super.getenv(name) + } + + override def clone: SparkConf = { + new MySparkConf().setAll(getAll) + } + } + val conf = new MySparkConf() + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + + val listener = new SaveExecutorInfo + sc.addSparkListener(listener) + + // Trigger a job so that executors get added + sc.parallelize(1 to 100, 4).map(_.toString).count() + + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + listener.addedExecutorInfos.values.foreach { info => + assert(info.logUrlMap.nonEmpty) + info.logUrlMap.values.foreach { logUrl => + assert(new URL(logUrl).getHost === SPARK_PUBLIC_DNS) + } + } + } + + private class SaveExecutorInfo extends SparkListener { + val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() + + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { + addedExecutorInfos(executor.executorId) = executor.executorInfo + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 82628ad3abd99..46d745c4ecbfa 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,6 +21,8 @@ import java.io._ import scala.collection.mutable.ArrayBuffer +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.ByteStreams import org.scalatest.FunSuite import org.scalatest.Matchers import org.scalatest.concurrent.Timeouts @@ -141,7 +143,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--executor-memory 5g") @@ -180,7 +182,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -201,6 +203,18 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties } test("handles standalone cluster mode") { + testStandaloneCluster(useRest = true) + } + + test("handles legacy standalone cluster mode") { + testStandaloneCluster(useRest = false) + } + + /** + * Test whether the launch environment is correctly set up in standalone cluster mode. + * @param useRest whether to use the REST submission gateway introduced in Spark 1.3 + */ + private def testStandaloneCluster(useRest: Boolean): Unit = { val clArgs = Seq( "--deploy-mode", "cluster", "--master", "spark://h:p", @@ -212,17 +226,26 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + appArgs.useRest = useRest + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") - childArgsStr should startWith ("--memory 4g --cores 5 --supervise") - childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2") - mainClass should be ("org.apache.spark.deploy.Client") - classpath should have size (0) - sysProps should have size (5) + if (useRest) { + childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") + mainClass should be ("org.apache.spark.deploy.rest.StandaloneRestClient") + } else { + childArgsStr should startWith ("--supervise --memory 4g --cores 5") + childArgsStr should include regex "launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2" + mainClass should be ("org.apache.spark.deploy.Client") + } + classpath should have size 0 + sysProps should have size 8 sysProps.keys should contain ("SPARK_SUBMIT") sysProps.keys should contain ("spark.master") sysProps.keys should contain ("spark.app.name") sysProps.keys should contain ("spark.jars") + sysProps.keys should contain ("spark.driver.memory") + sysProps.keys should contain ("spark.driver.cores") + sysProps.keys should contain ("spark.driver.supervise") sysProps.keys should contain ("spark.shuffle.spill") sysProps("spark.shuffle.spill") should be ("false") } @@ -239,7 +262,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -261,7 +284,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -281,7 +304,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, mainClass) = createLaunchEnv(appArgs) + val (_, _, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) sysProps("spark.executor.memory") should be ("5g") sysProps("spark.master") should be ("yarn-cluster") mainClass should be ("org.apache.spark.deploy.yarn.Client") @@ -307,7 +330,21 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--name", "testApp", "--master", "local-cluster[2,1,512]", "--jars", jarsString, - unusedJar.toString) + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("includes jars passed in through --packages") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val packagesString = "com.databricks:spark-csv_2.10:0.1,com.databricks:spark-avro_2.10:0.1" + val args = Seq( + "--class", JarCreationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local-cluster[2,1,512]", + "--packages", packagesString, + "--conf", "spark.ui.enabled=false", + unusedJar.toString, + "com.databricks.spark.csv.DefaultSource", "com.databricks.spark.avro.DefaultSource") runSparkSubmit(args) } @@ -325,7 +362,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3 + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) sysProps("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) @@ -340,7 +377,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3 + val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should be (Utils.resolveURIs(archives)) sysProps2("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) @@ -353,7 +390,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3 + val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) sysProps3("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -378,7 +415,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3 + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 sysProps("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) sysProps("spark.files") should be(Utils.resolveURIs(files)) @@ -395,7 +432,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3 + val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 sysProps2("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) @@ -410,11 +447,24 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3 + val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 sysProps3("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) } + test("user classpath first in driver") { + val systemJar = TestUtils.createJarWithFiles(Map("test.resource" -> "SYSTEM")) + val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER")) + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + "--conf", "spark.driver.extraClassPath=" + systemJar, + "--conf", "spark.driver.userClassPathFirst=true", + userJar.toString) + runSparkSubmit(args) + } + test("SPARK_CONF_DIR overrides spark-defaults.conf") { forConfDir(Map("spark.executor.memory" -> "2.3g")) { path => val unusedJar = TestUtils.createJarWithClasses(Seq.empty) @@ -426,7 +476,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path)) assert(appArgs.propertiesFile != null) assert(appArgs.propertiesFile.startsWith(path)) - appArgs.executorMemory should be ("2.3g") + appArgs.executorMemory should be ("2.3g") } } @@ -467,8 +517,8 @@ object JarCreationTest extends Logging { val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => var exception: String = null try { - Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader) - Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader) + Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) + Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) } catch { case t: Throwable => exception = t + "\n" + t.getStackTraceString @@ -506,3 +556,15 @@ object SimpleApplicationTest { } } } + +object UserClasspathFirstTest { + def main(args: Array[String]) { + val ccl = Thread.currentThread().getContextClassLoader() + val resource = ccl.getResourceAsStream("test.resource") + val bytes = ByteStreams.toByteArray(resource) + val contents = new String(bytes, 0, bytes.length, UTF_8) + if (contents != "USER") { + throw new SparkException("Should have read user resource, but instead read: " + contents) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala new file mode 100644 index 0000000000000..8bcca926097a1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.{PrintStream, OutputStream, File} + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.ivy.core.module.descriptor.MDArtifact +import org.apache.ivy.plugins.resolver.IBiblioResolver + +class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { + + private val noOpOutputStream = new OutputStream { + def write(b: Int) = {} + } + + /** Simple PrintStream that reads data into a buffer */ + private class BufferPrintStream extends PrintStream(noOpOutputStream) { + var lineBuffer = ArrayBuffer[String]() + override def println(line: String) { + lineBuffer += line + } + } + + override def beforeAll() { + super.beforeAll() + // We don't want to write logs during testing + SparkSubmitUtils.printStream = new BufferPrintStream + } + + test("incorrect maven coordinate throws error") { + val coordinates = Seq("a:b: ", " :a:b", "a: :b", "a:b:", ":a:b", "a::b", "::", "a:b", "a") + for (coordinate <- coordinates) { + intercept[IllegalArgumentException] { + SparkSubmitUtils.extractMavenCoordinates(coordinate) + } + } + } + + test("create repo resolvers") { + val resolver1 = SparkSubmitUtils.createRepoResolvers(None) + // should have central and spark-packages by default + assert(resolver1.getResolvers.size() === 2) + assert(resolver1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "central") + assert(resolver1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "spark-packages") + + val repos = "a/1,b/2,c/3" + val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos)) + assert(resolver2.getResolvers.size() === 5) + val expected = repos.split(",").map(r => s"$r/") + resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: IBiblioResolver, i) => + if (i == 0) { + assert(resolver.getName === "central") + } else if (i == 1) { + assert(resolver.getName === "spark-packages") + } else { + assert(resolver.getName === s"repo-${i - 1}") + assert(resolver.getRoot === expected(i - 2)) + } + } + } + + test("add dependencies works correctly") { + val md = SparkSubmitUtils.getModuleDescriptor + val artifacts = SparkSubmitUtils.extractMavenCoordinates("com.databricks:spark-csv_2.10:0.1," + + "com.databricks:spark-avro_2.10:0.1") + + SparkSubmitUtils.addDependenciesToIvy(md, artifacts, "default") + assert(md.getDependencies.length === 2) + } + + test("ivy path works correctly") { + val ivyPath = "dummy/ivy" + val md = SparkSubmitUtils.getModuleDescriptor + val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") + var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) + for (i <- 0 until 3) { + val index = jPaths.indexOf(ivyPath) + assert(index >= 0) + jPaths = jPaths.substring(index + ivyPath.length) + } + // end to end + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + "com.databricks:spark-csv_2.10:0.1", None, Option(ivyPath), true) + assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + } + + test("search for artifact at other repositories") { + val path = SparkSubmitUtils.resolveMavenCoordinates("com.agimatec:agimatec-validation:0.9.3", + Option("https://oss.sonatype.org/content/repositories/agimatec/"), None, true) + assert(path.indexOf("agimatec-validation") >= 0, "should find package. If it doesn't, check" + + "if package still exists. If it has been removed, replace the example in this test.") + } + + test("dependency not found throws RuntimeException") { + intercept[RuntimeException] { + SparkSubmitUtils.resolveMavenCoordinates("a:b:c", None, None, true) + } + } + + test("neglects Spark and Spark's dependencies") { + val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", + "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") + + val coordinates = + components.map(comp => s"org.apache.spark:spark-${comp}2.10:1.2.0").mkString(",") + + ",org.apache.spark:spark-core_fake:1.2.0" + + val path = SparkSubmitUtils.resolveMavenCoordinates(coordinates, None, None, true) + assert(path === "", "should return empty path") + // Should not exclude the following dependency. Will throw an error, because it doesn't exist, + // but the fact that it is checking means that it wasn't excluded. + intercept[RuntimeException] { + SparkSubmitUtils.resolveMavenCoordinates(coordinates + + ",org.apache.spark:spark-streaming-kafka-assembly_2.10:1.2.0", None, None, true) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 3fbc1a21d10ed..fcae603c7d18e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -17,18 +17,17 @@ package org.apache.spark.deploy.history -import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStreamWriter} +import java.net.URI import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.Matchers import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io._ import org.apache.spark.scheduler._ import org.apache.spark.util.{JsonProtocol, Utils} @@ -37,54 +36,67 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers private var testDir: File = null - private var provider: FsHistoryProvider = null - before { testDir = Utils.createTempDir() - provider = new FsHistoryProvider(new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) - .set("spark.history.fs.updateInterval", "0")) } after { Utils.deleteRecursively(testDir) } + /** Create a fake log file using the new log format used in Spark 1.3+ */ + private def newLogFile( + appId: String, + inProgress: Boolean, + codec: Option[String] = None): File = { + val ip = if (inProgress) EventLoggingListener.IN_PROGRESS else "" + val logUri = EventLoggingListener.getLogPath(testDir.toURI, appId) + val logPath = new URI(logUri).getPath + ip + new File(logPath) + } + test("Parse new and old application logs") { - val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) - .set("spark.history.fs.updateInterval", "0") - val provider = new FsHistoryProvider(conf) + val provider = new FsHistoryProvider(createTestConf()) // Write a new-style application log. - val logFile1 = new File(testDir, "new1") - writeFile(logFile1, true, None, - SparkListenerApplicationStart("app1-1", None, 1L, "test"), - SparkListenerApplicationEnd(2L) + val newAppComplete = newLogFile("new1", inProgress = false) + writeFile(newAppComplete, true, None, + SparkListenerApplicationStart("new-app-complete", None, 1L, "test"), + SparkListenerApplicationEnd(5L) ) + // Write a new-style application log. + val newAppCompressedComplete = newLogFile("new1compressed", inProgress = false, Some("lzf")) + writeFile(newAppCompressedComplete, true, None, + SparkListenerApplicationStart("new-app-compressed-complete", None, 1L, "test"), + SparkListenerApplicationEnd(4L)) + // Write an unfinished app, new-style. - val logFile2 = new File(testDir, "new2" + EventLoggingListener.IN_PROGRESS) - writeFile(logFile2, true, None, - SparkListenerApplicationStart("app2-2", None, 1L, "test") + val newAppIncomplete = newLogFile("new2", inProgress = true) + writeFile(newAppIncomplete, true, None, + SparkListenerApplicationStart("new-app-incomplete", None, 1L, "test") ) // Write an old-style application log. - val oldLog = new File(testDir, "old1") - oldLog.mkdir() - createEmptyFile(new File(oldLog, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldLog, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("app3", None, 2L, "test"), + val oldAppComplete = new File(testDir, "old1") + oldAppComplete.mkdir() + createEmptyFile(new File(oldAppComplete, provider.SPARK_VERSION_PREFIX + "1.0")) + writeFile(new File(oldAppComplete, provider.LOG_PREFIX + "1"), false, None, + SparkListenerApplicationStart("old-app-complete", None, 2L, "test"), SparkListenerApplicationEnd(3L) ) - createEmptyFile(new File(oldLog, provider.APPLICATION_COMPLETE)) + createEmptyFile(new File(oldAppComplete, provider.APPLICATION_COMPLETE)) + + // Check for logs so that we force the older unfinished app to be loaded, to make + // sure unfinished apps are also sorted correctly. + provider.checkForLogs() // Write an unfinished app, old-style. - val oldLog2 = new File(testDir, "old2") - oldLog2.mkdir() - createEmptyFile(new File(oldLog2, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldLog2, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("app4", None, 2L, "test") + val oldAppIncomplete = new File(testDir, "old2") + oldAppIncomplete.mkdir() + createEmptyFile(new File(oldAppIncomplete, provider.SPARK_VERSION_PREFIX + "1.0")) + writeFile(new File(oldAppIncomplete, provider.LOG_PREFIX + "1"), false, None, + SparkListenerApplicationStart("old-app-incomplete", None, 2L, "test") ) // Force a reload of data from the log directory, and check that both logs are loaded. @@ -93,17 +105,19 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers val list = provider.getListing().toSeq list should not be (null) - list.size should be (4) - list.count(e => e.completed) should be (2) - - list(0) should be (ApplicationHistoryInfo(oldLog.getName(), "app3", 2L, 3L, - oldLog.lastModified(), "test", true)) - list(1) should be (ApplicationHistoryInfo(logFile1.getName(), "app1-1", 1L, 2L, - logFile1.lastModified(), "test", true)) - list(2) should be (ApplicationHistoryInfo(oldLog2.getName(), "app4", 2L, -1L, - oldLog2.lastModified(), "test", false)) - list(3) should be (ApplicationHistoryInfo(logFile2.getName(), "app2-2", 1L, -1L, - logFile2.lastModified(), "test", false)) + list.size should be (5) + list.count(_.completed) should be (3) + + list(0) should be (ApplicationHistoryInfo(newAppComplete.getName(), "new-app-complete", 1L, 5L, + newAppComplete.lastModified(), "test", true)) + list(1) should be (ApplicationHistoryInfo(newAppCompressedComplete.getName(), + "new-app-compressed-complete", 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) + list(2) should be (ApplicationHistoryInfo(oldAppComplete.getName(), "old-app-complete", 2L, 3L, + oldAppComplete.lastModified(), "test", true)) + list(3) should be (ApplicationHistoryInfo(oldAppIncomplete.getName(), "old-app-incomplete", 2L, + -1L, oldAppIncomplete.lastModified(), "test", false)) + list(4) should be (ApplicationHistoryInfo(newAppIncomplete.getName(), "new-app-incomplete", 1L, + -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. list.foreach { case info => @@ -113,6 +127,7 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers } test("Parse legacy logs with compression codec set") { + val provider = new FsHistoryProvider(createTestConf()) val testCodecs = List((classOf[LZFCompressionCodec].getName(), true), (classOf[SnappyCompressionCodec].getName(), true), ("invalid.codec", false)) @@ -130,7 +145,7 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers val logPath = new Path(logDir.getAbsolutePath()) try { - val (logInput, sparkVersion) = provider.openLegacyEventLog(logPath) + val logInput = provider.openLegacyEventLog(logPath) try { Source.fromInputStream(logInput).getLines().toSeq.size should be (2) } finally { @@ -144,22 +159,19 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers } test("SPARK-3697: ignore directories that cannot be read.") { - val logFile1 = new File(testDir, "new1") + val logFile1 = newLogFile("new1", inProgress = false) writeFile(logFile1, true, None, SparkListenerApplicationStart("app1-1", None, 1L, "test"), SparkListenerApplicationEnd(2L) ) - val logFile2 = new File(testDir, "new2") + val logFile2 = newLogFile("new2", inProgress = false) writeFile(logFile2, true, None, SparkListenerApplicationStart("app1-2", None, 1L, "test"), SparkListenerApplicationEnd(2L) ) logFile2.setReadable(false, false) - val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) - .set("spark.history.fs.updateInterval", "0") - val provider = new FsHistoryProvider(conf) + val provider = new FsHistoryProvider(createTestConf()) provider.checkForLogs() val list = provider.getListing().toSeq @@ -168,12 +180,9 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers } test("history file is renamed from inprogress to completed") { - val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) - .set("spark.testing", "true") - val provider = new FsHistoryProvider(conf) + val provider = new FsHistoryProvider(createTestConf()) - val logFile1 = new File(testDir, "app1" + EventLoggingListener.IN_PROGRESS) + val logFile1 = newLogFile("app1", inProgress = true) writeFile(logFile1, true, None, SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"), SparkListenerApplicationEnd(2L) @@ -183,23 +192,38 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers appListBeforeRename.size should be (1) appListBeforeRename.head.logPath should endWith(EventLoggingListener.IN_PROGRESS) - logFile1.renameTo(new File(testDir, "app1")) + logFile1.renameTo(newLogFile("app1", inProgress = false)) provider.checkForLogs() val appListAfterRename = provider.getListing() appListAfterRename.size should be (1) appListAfterRename.head.logPath should not endWith(EventLoggingListener.IN_PROGRESS) } + test("SPARK-5582: empty log directory") { + val provider = new FsHistoryProvider(createTestConf()) + + val logFile1 = newLogFile("app1", inProgress = true) + writeFile(logFile1, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"), + SparkListenerApplicationEnd(2L)) + + val oldLog = new File(testDir, "old1") + oldLog.mkdir() + + provider.checkForLogs() + val appListAfterRename = provider.getListing() + appListAfterRename.size should be (1) + } + private def writeFile(file: File, isNewFormat: Boolean, codec: Option[CompressionCodec], events: SparkListenerEvent*) = { - val out = - if (isNewFormat) { - EventLoggingListener.initEventLog(new FileOutputStream(file), codec) - } else { - val fileStream = new FileOutputStream(file) - codec.map(_.compressedOutputStream(fileStream)).getOrElse(fileStream) - } - val writer = new OutputStreamWriter(out, "UTF-8") + val fstream = new FileOutputStream(file) + val cstream = codec.map(_.compressedOutputStream(fstream)).getOrElse(fstream) + val bstream = new BufferedOutputStream(cstream) + if (isNewFormat) { + EventLoggingListener.initEventLog(new FileOutputStream(file)) + } + val writer = new OutputStreamWriter(bstream, "UTF-8") try { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) } finally { @@ -211,4 +235,8 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers new FileOutputStream(file).close() } + private def createTestConf(): SparkConf = { + new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala new file mode 100644 index 0000000000000..3a9963a5ce7b7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import javax.servlet.http.HttpServletRequest + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.mockito.Mockito.{when} +import org.scalatest.FunSuite +import org.scalatest.Matchers +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.ui.SparkUI + +class HistoryServerSuite extends FunSuite with Matchers with MockitoSugar { + + test("generate history page with relative links") { + val historyServer = mock[HistoryServer] + val request = mock[HttpServletRequest] + val ui = mock[SparkUI] + val link = "/history/app1" + val info = new ApplicationHistoryInfo("app1", "app1", 0, 2, 1, "xxx", true) + when(historyServer.getApplicationList()).thenReturn(Seq(info)) + when(ui.basePath).thenReturn(link) + when(historyServer.getProviderConfig()).thenReturn(Map[String, String]()) + val page = new HistoryPage(historyServer) + + //when + val response = page.render(request) + + //then + val links = response \\ "a" + val justHrefs = for { + l <- links + attrs <- l.attribute("href") + } yield (attrs.toString) + justHrefs should contain(link) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala new file mode 100644 index 0000000000000..2fa90e3bd1c63 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -0,0 +1,606 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.DataOutputStream +import java.net.{HttpURLConnection, URL} +import javax.servlet.http.HttpServletResponse + +import scala.collection.mutable + +import akka.actor.{Actor, ActorRef, ActorSystem, Props} +import com.google.common.base.Charsets +import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark._ +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} +import org.apache.spark.deploy.master.DriverState._ + +/** + * Tests for the REST application submission protocol used in standalone cluster mode. + */ +class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { + private val client = new StandaloneRestClient + private var actorSystem: Option[ActorSystem] = None + private var server: Option[StandaloneRestServer] = None + + override def afterEach() { + actorSystem.foreach(_.shutdown()) + server.foreach(_.stop()) + } + + test("construct submit request") { + val appArgs = Array("one", "two", "three") + val sparkProperties = Map("spark.app.name" -> "pi") + val environmentVariables = Map("SPARK_ONE" -> "UN", "SPARK_TWO" -> "DEUX") + val request = client.constructSubmitRequest( + "my-app-resource", "my-main-class", appArgs, sparkProperties, environmentVariables) + assert(request.action === Utils.getFormattedClassName(request)) + assert(request.clientSparkVersion === SPARK_VERSION) + assert(request.appResource === "my-app-resource") + assert(request.mainClass === "my-main-class") + assert(request.appArgs === appArgs) + assert(request.sparkProperties === sparkProperties) + assert(request.environmentVariables === environmentVariables) + } + + test("create submission") { + val submittedDriverId = "my-driver-id" + val submitMessage = "your driver is submitted" + val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage) + val appArgs = Array("one", "two", "four") + val request = constructSubmitRequest(masterUrl, appArgs) + assert(request.appArgs === appArgs) + assert(request.sparkProperties("spark.master") === masterUrl) + val response = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response) + assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) + assert(submitResponse.serverSparkVersion === SPARK_VERSION) + assert(submitResponse.message === submitMessage) + assert(submitResponse.submissionId === submittedDriverId) + assert(submitResponse.success) + } + + test("create submission from main method") { + val submittedDriverId = "your-driver-id" + val submitMessage = "my driver is submitted" + val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage) + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.master", masterUrl) + conf.set("spark.app.name", "dreamer") + val appArgs = Array("one", "two", "six") + // main method calls this + val response = StandaloneRestClient.run("app-resource", "main-class", appArgs, conf) + val submitResponse = getSubmitResponse(response) + assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) + assert(submitResponse.serverSparkVersion === SPARK_VERSION) + assert(submitResponse.message === submitMessage) + assert(submitResponse.submissionId === submittedDriverId) + assert(submitResponse.success) + } + + test("kill submission") { + val submissionId = "my-lyft-driver" + val killMessage = "your driver is killed" + val masterUrl = startDummyServer(killMessage = killMessage) + val response = client.killSubmission(masterUrl, submissionId) + val killResponse = getKillResponse(response) + assert(killResponse.action === Utils.getFormattedClassName(killResponse)) + assert(killResponse.serverSparkVersion === SPARK_VERSION) + assert(killResponse.message === killMessage) + assert(killResponse.submissionId === submissionId) + assert(killResponse.success) + } + + test("request submission status") { + val submissionId = "my-uber-driver" + val submissionState = KILLED + val submissionException = new Exception("there was an irresponsible mix of alcohol and cars") + val masterUrl = startDummyServer(state = submissionState, exception = Some(submissionException)) + val response = client.requestSubmissionStatus(masterUrl, submissionId) + val statusResponse = getStatusResponse(response) + assert(statusResponse.action === Utils.getFormattedClassName(statusResponse)) + assert(statusResponse.serverSparkVersion === SPARK_VERSION) + assert(statusResponse.message.contains(submissionException.getMessage)) + assert(statusResponse.submissionId === submissionId) + assert(statusResponse.driverState === submissionState.toString) + assert(statusResponse.success) + } + + test("create then kill") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response1) + assert(submitResponse.success) + assert(submitResponse.submissionId != null) + // kill submission that was just created + val submissionId = submitResponse.submissionId + val response2 = client.killSubmission(masterUrl, submissionId) + val killResponse = getKillResponse(response2) + assert(killResponse.success) + assert(killResponse.submissionId === submissionId) + } + + test("create then request status") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response1) + assert(submitResponse.success) + assert(submitResponse.submissionId != null) + // request status of submission that was just created + val submissionId = submitResponse.submissionId + val response2 = client.requestSubmissionStatus(masterUrl, submissionId) + val statusResponse = getStatusResponse(response2) + assert(statusResponse.success) + assert(statusResponse.submissionId === submissionId) + assert(statusResponse.driverState === RUNNING.toString) + } + + test("create then kill then request status") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val response2 = client.createSubmission(masterUrl, request) + val submitResponse1 = getSubmitResponse(response1) + val submitResponse2 = getSubmitResponse(response2) + assert(submitResponse1.success) + assert(submitResponse2.success) + assert(submitResponse1.submissionId != null) + assert(submitResponse2.submissionId != null) + val submissionId1 = submitResponse1.submissionId + val submissionId2 = submitResponse2.submissionId + // kill only submission 1, but not submission 2 + val response3 = client.killSubmission(masterUrl, submissionId1) + val killResponse = getKillResponse(response3) + assert(killResponse.success) + assert(killResponse.submissionId === submissionId1) + // request status for both submissions: 1 should be KILLED but 2 should be RUNNING still + val response4 = client.requestSubmissionStatus(masterUrl, submissionId1) + val response5 = client.requestSubmissionStatus(masterUrl, submissionId2) + val statusResponse1 = getStatusResponse(response4) + val statusResponse2 = getStatusResponse(response5) + assert(statusResponse1.submissionId === submissionId1) + assert(statusResponse2.submissionId === submissionId2) + assert(statusResponse1.driverState === KILLED.toString) + assert(statusResponse2.driverState === RUNNING.toString) + } + + test("kill or request status before create") { + val masterUrl = startSmartServer() + val doesNotExist = "does-not-exist" + // kill a non-existent submission + val response1 = client.killSubmission(masterUrl, doesNotExist) + val killResponse = getKillResponse(response1) + assert(!killResponse.success) + assert(killResponse.submissionId === doesNotExist) + // request status for a non-existent submission + val response2 = client.requestSubmissionStatus(masterUrl, doesNotExist) + val statusResponse = getStatusResponse(response2) + assert(!statusResponse.success) + assert(statusResponse.submissionId === doesNotExist) + } + + /* ---------------------------------------- * + | Aberrant client / server behavior | + * ---------------------------------------- */ + + test("good request paths") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val json = constructSubmitRequest(masterUrl).toJson + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill" + val statusRequestPath = s"$httpUrl/$v/submissions/status" + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST", json) + val (response2, code2) = sendHttpRequestWithResponse(s"$killRequestPath/anything", "POST") + val (response3, code3) = sendHttpRequestWithResponse(s"$killRequestPath/any/thing", "POST") + val (response4, code4) = sendHttpRequestWithResponse(s"$statusRequestPath/anything", "GET") + val (response5, code5) = sendHttpRequestWithResponse(s"$statusRequestPath/any/thing", "GET") + // these should all succeed and the responses should be of the correct types + getSubmitResponse(response1) + val killResponse1 = getKillResponse(response2) + val killResponse2 = getKillResponse(response3) + val statusResponse1 = getStatusResponse(response4) + val statusResponse2 = getStatusResponse(response5) + assert(killResponse1.submissionId === "anything") + assert(killResponse2.submissionId === "any") + assert(statusResponse1.submissionId === "anything") + assert(statusResponse2.submissionId === "any") + assert(code1 === HttpServletResponse.SC_OK) + assert(code2 === HttpServletResponse.SC_OK) + assert(code3 === HttpServletResponse.SC_OK) + assert(code4 === HttpServletResponse.SC_OK) + assert(code5 === HttpServletResponse.SC_OK) + } + + test("good request paths, bad requests") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill" + val statusRequestPath = s"$httpUrl/$v/submissions/status" + val goodJson = constructSubmitRequest(masterUrl).toJson + val badJson1 = goodJson.replaceAll("action", "fraction") // invalid JSON + val badJson2 = goodJson.substring(goodJson.size / 2) // malformed JSON + val notJson = "\"hello, world\"" + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST") // missing JSON + val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, "POST", badJson1) + val (response3, code3) = sendHttpRequestWithResponse(submitRequestPath, "POST", badJson2) + val (response4, code4) = sendHttpRequestWithResponse(killRequestPath, "POST") // missing ID + val (response5, code5) = sendHttpRequestWithResponse(s"$killRequestPath/", "POST") + val (response6, code6) = sendHttpRequestWithResponse(statusRequestPath, "GET") // missing ID + val (response7, code7) = sendHttpRequestWithResponse(s"$statusRequestPath/", "GET") + val (response8, code8) = sendHttpRequestWithResponse(submitRequestPath, "POST", notJson) + // these should all fail as error responses + getErrorResponse(response1) + getErrorResponse(response2) + getErrorResponse(response3) + getErrorResponse(response4) + getErrorResponse(response5) + getErrorResponse(response6) + getErrorResponse(response7) + getErrorResponse(response8) + assert(code1 === HttpServletResponse.SC_BAD_REQUEST) + assert(code2 === HttpServletResponse.SC_BAD_REQUEST) + assert(code3 === HttpServletResponse.SC_BAD_REQUEST) + assert(code4 === HttpServletResponse.SC_BAD_REQUEST) + assert(code5 === HttpServletResponse.SC_BAD_REQUEST) + assert(code6 === HttpServletResponse.SC_BAD_REQUEST) + assert(code7 === HttpServletResponse.SC_BAD_REQUEST) + assert(code8 === HttpServletResponse.SC_BAD_REQUEST) + } + + test("bad request paths") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val (response1, code1) = sendHttpRequestWithResponse(httpUrl, "GET") + val (response2, code2) = sendHttpRequestWithResponse(s"$httpUrl/", "GET") + val (response3, code3) = sendHttpRequestWithResponse(s"$httpUrl/$v", "GET") + val (response4, code4) = sendHttpRequestWithResponse(s"$httpUrl/$v/", "GET") + val (response5, code5) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions", "GET") + val (response6, code6) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/", "GET") + val (response7, code7) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/bad", "GET") + val (response8, code8) = sendHttpRequestWithResponse(s"$httpUrl/bad-version", "GET") + assert(code1 === HttpServletResponse.SC_BAD_REQUEST) + assert(code2 === HttpServletResponse.SC_BAD_REQUEST) + assert(code3 === HttpServletResponse.SC_BAD_REQUEST) + assert(code4 === HttpServletResponse.SC_BAD_REQUEST) + assert(code5 === HttpServletResponse.SC_BAD_REQUEST) + assert(code6 === HttpServletResponse.SC_BAD_REQUEST) + assert(code7 === HttpServletResponse.SC_BAD_REQUEST) + assert(code8 === StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + // all responses should be error responses + val errorResponse1 = getErrorResponse(response1) + val errorResponse2 = getErrorResponse(response2) + val errorResponse3 = getErrorResponse(response3) + val errorResponse4 = getErrorResponse(response4) + val errorResponse5 = getErrorResponse(response5) + val errorResponse6 = getErrorResponse(response6) + val errorResponse7 = getErrorResponse(response7) + val errorResponse8 = getErrorResponse(response8) + // only the incompatible version response should have server protocol version set + assert(errorResponse1.highestProtocolVersion === null) + assert(errorResponse2.highestProtocolVersion === null) + assert(errorResponse3.highestProtocolVersion === null) + assert(errorResponse4.highestProtocolVersion === null) + assert(errorResponse5.highestProtocolVersion === null) + assert(errorResponse6.highestProtocolVersion === null) + assert(errorResponse7.highestProtocolVersion === null) + assert(errorResponse8.highestProtocolVersion === StandaloneRestServer.PROTOCOL_VERSION) + } + + test("server returns unknown fields") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val oldJson = constructSubmitRequest(masterUrl).toJson + val oldFields = parse(oldJson).asInstanceOf[JObject].obj + val newFields = oldFields ++ Seq( + JField("tomato", JString("not-a-fruit")), + JField("potato", JString("not-po-tah-to")) + ) + val newJson = pretty(render(JObject(newFields))) + // send two requests, one with the unknown fields and the other without + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST", oldJson) + val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, "POST", newJson) + val submitResponse1 = getSubmitResponse(response1) + val submitResponse2 = getSubmitResponse(response2) + assert(code1 === HttpServletResponse.SC_OK) + assert(code2 === HttpServletResponse.SC_OK) + // only the response to the modified request should have unknown fields set + assert(submitResponse1.unknownFields === null) + assert(submitResponse2.unknownFields === Array("tomato", "potato")) + } + + test("client handles faulty server") { + val masterUrl = startFaultyServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill/anything" + val statusRequestPath = s"$httpUrl/$v/submissions/status/anything" + val json = constructSubmitRequest(masterUrl).toJson + // server returns malformed response unwittingly + // client should throw an appropriate exception to indicate server failure + val conn1 = sendHttpRequest(submitRequestPath, "POST", json) + intercept[SubmitRestProtocolException] { client.readResponse(conn1) } + // server attempts to send invalid response, but fails internally on validation + // client should receive an error response as server is able to recover + val conn2 = sendHttpRequest(killRequestPath, "POST") + val response2 = client.readResponse(conn2) + getErrorResponse(response2) + assert(conn2.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + // server explodes internally beyond recovery + // client should throw an appropriate exception to indicate server failure + val conn3 = sendHttpRequest(statusRequestPath, "GET") + intercept[SubmitRestProtocolException] { client.readResponse(conn3) } // empty response + assert(conn3.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + } + + /* --------------------- * + | Helper methods | + * --------------------- */ + + /** Start a dummy server that responds to requests using the specified parameters. */ + private def startDummyServer( + submitId: String = "fake-driver-id", + submitMessage: String = "driver is submitted", + killMessage: String = "driver is killed", + state: DriverState = FINISHED, + exception: Option[Exception] = None): String = { + startServer(new DummyMaster(submitId, submitMessage, killMessage, state, exception)) + } + + /** Start a smarter dummy server that keeps track of submitted driver states. */ + private def startSmartServer(): String = { + startServer(new SmarterMaster) + } + + /** Start a dummy server that is faulty in many ways... */ + private def startFaultyServer(): String = { + startServer(new DummyMaster, faulty = true) + } + + /** + * Start a [[StandaloneRestServer]] that communicates with the given actor. + * If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead. + * Return the master URL that corresponds to the address of this server. + */ + private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): String = { + val name = "test-standalone-rest-protocol" + val conf = new SparkConf + val localhost = Utils.localHostName() + val securityManager = new SecurityManager(conf) + val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, conf, securityManager) + val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) + val _server = + if (faulty) { + new FaultyStandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + } else { + new StandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + } + val port = _server.start() + // set these to clean them up after every test + actorSystem = Some(_actorSystem) + server = Some(_server) + s"spark://$localhost:$port" + } + + /** Create a submit request with real parameters using Spark submit. */ + private def constructSubmitRequest( + masterUrl: String, + appArgs: Array[String] = Array.empty): CreateSubmissionRequest = { + val mainClass = "main-class-not-used" + val mainJar = "dummy-jar-not-used.jar" + val commandLineArgs = Array( + "--deploy-mode", "cluster", + "--master", masterUrl, + "--name", mainClass, + "--class", mainClass, + mainJar) ++ appArgs + val args = new SparkSubmitArguments(commandLineArgs) + val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args) + client.constructSubmitRequest( + mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty) + } + + /** Return the response as a submit response, or fail with error otherwise. */ + private def getSubmitResponse(response: SubmitRestProtocolResponse): CreateSubmissionResponse = { + response match { + case s: CreateSubmissionResponse => s + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected submit response. Actual: ${r.toJson}") + } + } + + /** Return the response as a kill response, or fail with error otherwise. */ + private def getKillResponse(response: SubmitRestProtocolResponse): KillSubmissionResponse = { + response match { + case k: KillSubmissionResponse => k + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected kill response. Actual: ${r.toJson}") + } + } + + /** Return the response as a status response, or fail with error otherwise. */ + private def getStatusResponse(response: SubmitRestProtocolResponse): SubmissionStatusResponse = { + response match { + case s: SubmissionStatusResponse => s + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected status response. Actual: ${r.toJson}") + } + } + + /** Return the response as an error response, or fail if the response was not an error. */ + private def getErrorResponse(response: SubmitRestProtocolResponse): ErrorResponse = { + response match { + case e: ErrorResponse => e + case r => fail(s"Expected error response. Actual: ${r.toJson}") + } + } + + /** + * Send an HTTP request to the given URL using the method and the body specified. + * Return the connection object. + */ + private def sendHttpRequest( + url: String, + method: String, + body: String = ""): HttpURLConnection = { + val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod(method) + if (body.nonEmpty) { + conn.setDoOutput(true) + val out = new DataOutputStream(conn.getOutputStream) + out.write(body.getBytes(Charsets.UTF_8)) + out.close() + } + conn + } + + /** + * Send an HTTP request to the given URL using the method and the body specified. + * Return a 2-tuple of the response message from the server and the response code. + */ + private def sendHttpRequestWithResponse( + url: String, + method: String, + body: String = ""): (SubmitRestProtocolResponse, Int) = { + val conn = sendHttpRequest(url, method, body) + (client.readResponse(conn), conn.getResponseCode) + } +} + +/** + * A mock standalone Master that responds with dummy messages. + * In all responses, the success parameter is always true. + */ +private class DummyMaster( + submitId: String = "fake-driver-id", + submitMessage: String = "submitted", + killMessage: String = "killed", + state: DriverState = FINISHED, + exception: Option[Exception] = None) + extends Actor { + + override def receive = { + case RequestSubmitDriver(driverDesc) => + sender ! SubmitDriverResponse(success = true, Some(submitId), submitMessage) + case RequestKillDriver(driverId) => + sender ! KillDriverResponse(driverId, success = true, killMessage) + case RequestDriverStatus(driverId) => + sender ! DriverStatusResponse(found = true, Some(state), None, None, exception) + } +} + +/** + * A mock standalone Master that keeps track of drivers that have been submitted. + * + * If a driver is submitted, its state is immediately set to RUNNING. + * If an existing driver is killed, its state is immediately set to KILLED. + * If an existing driver's status is requested, its state is returned in the response. + * Submits are always successful while kills and status requests are successful only + * if the driver was submitted in the past. + */ +private class SmarterMaster extends Actor { + private var counter: Int = 0 + private val submittedDrivers = new mutable.HashMap[String, DriverState] + + override def receive = { + case RequestSubmitDriver(driverDesc) => + val driverId = s"driver-$counter" + submittedDrivers(driverId) = RUNNING + counter += 1 + sender ! SubmitDriverResponse(success = true, Some(driverId), "submitted") + + case RequestKillDriver(driverId) => + val success = submittedDrivers.contains(driverId) + if (success) { + submittedDrivers(driverId) = KILLED + } + sender ! KillDriverResponse(driverId, success, "killed") + + case RequestDriverStatus(driverId) => + val found = submittedDrivers.contains(driverId) + val state = submittedDrivers.get(driverId) + sender ! DriverStatusResponse(found, state, None, None, None) + } +} + +/** + * A [[StandaloneRestServer]] that is faulty in many ways. + * + * When handling a submit request, the server returns a malformed JSON. + * When handling a kill request, the server returns an invalid JSON. + * When handling a status request, the server throws an internal exception. + * The purpose of this class is to test that client handles these cases gracefully. + */ +private class FaultyStandaloneRestServer( + host: String, + requestedPort: Int, + masterActor: ActorRef, + masterUrl: String, + masterConf: SparkConf) + extends StandaloneRestServer(host, requestedPort, masterActor, masterUrl, masterConf) { + + protected override val contextToServlet = Map[String, StandaloneRestServlet]( + s"$baseContext/create/*" -> new MalformedSubmitServlet, + s"$baseContext/kill/*" -> new InvalidKillServlet, + s"$baseContext/status/*" -> new ExplodingStatusServlet, + "/*" -> new ErrorServlet + ) + + /** A faulty servlet that produces malformed responses. */ + class MalformedSubmitServlet extends SubmitRequestServlet(masterActor, masterUrl, masterConf) { + protected override def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val badJson = responseMessage.toJson.drop(10).dropRight(20) + responseServlet.getWriter.write(badJson) + } + } + + /** A faulty servlet that produces invalid responses. */ + class InvalidKillServlet extends KillRequestServlet(masterActor, masterConf) { + protected override def handleKill(submissionId: String): KillSubmissionResponse = { + val k = super.handleKill(submissionId) + k.submissionId = null + k + } + } + + /** A faulty status servlet that explodes. */ + class ExplodingStatusServlet extends StatusRequestServlet(masterActor, masterConf) { + private def explode: Int = 1 / 0 + protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { + val s = super.handleStatus(submissionId) + s.workerId = explode.toString + s + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala new file mode 100644 index 0000000000000..1d64ec201e647 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.lang.Boolean +import java.lang.Integer + +import org.json4s.jackson.JsonMethods._ +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf + +/** + * Tests for the REST application submission protocol. + */ +class SubmitRestProtocolSuite extends FunSuite { + + test("validate") { + val request = new DummyRequest + intercept[SubmitRestProtocolException] { request.validate() } // missing everything + request.clientSparkVersion = "1.2.3" + intercept[SubmitRestProtocolException] { request.validate() } // missing name and age + request.name = "something" + intercept[SubmitRestProtocolException] { request.validate() } // missing only age + request.age = 2 + intercept[SubmitRestProtocolException] { request.validate() } // age too low + request.age = 10 + request.validate() // everything is set properly + request.clientSparkVersion = null + intercept[SubmitRestProtocolException] { request.validate() } // missing only Spark version + request.clientSparkVersion = "1.2.3" + request.name = null + intercept[SubmitRestProtocolException] { request.validate() } // missing only name + request.message = "not-setting-name" + intercept[SubmitRestProtocolException] { request.validate() } // still missing name + } + + test("request to and from JSON") { + val request = new DummyRequest + intercept[SubmitRestProtocolException] { request.toJson } // implicit validation + request.clientSparkVersion = "1.2.3" + request.active = true + request.age = 25 + request.name = "jung" + val json = request.toJson + assertJsonEquals(json, dummyRequestJson) + val newRequest = SubmitRestProtocolMessage.fromJson(json, classOf[DummyRequest]) + assert(newRequest.clientSparkVersion === "1.2.3") + assert(newRequest.clientSparkVersion === "1.2.3") + assert(newRequest.active) + assert(newRequest.age === 25) + assert(newRequest.name === "jung") + assert(newRequest.message === null) + } + + test("response to and from JSON") { + val response = new DummyResponse + response.serverSparkVersion = "3.3.4" + response.success = true + val json = response.toJson + assertJsonEquals(json, dummyResponseJson) + val newResponse = SubmitRestProtocolMessage.fromJson(json, classOf[DummyResponse]) + assert(newResponse.serverSparkVersion === "3.3.4") + assert(newResponse.serverSparkVersion === "3.3.4") + assert(newResponse.success) + assert(newResponse.message === null) + } + + test("CreateSubmissionRequest") { + val message = new CreateSubmissionRequest + intercept[SubmitRestProtocolException] { message.validate() } + message.clientSparkVersion = "1.2.3" + message.appResource = "honey-walnut-cherry.jar" + message.mainClass = "org.apache.spark.examples.SparkPie" + val conf = new SparkConf(false) + conf.set("spark.app.name", "SparkPie") + message.sparkProperties = conf.getAll.toMap + message.validate() + // optional fields + conf.set("spark.jars", "mayonnaise.jar,ketchup.jar") + conf.set("spark.files", "fireball.png") + conf.set("spark.driver.memory", "512m") + conf.set("spark.driver.cores", "180") + conf.set("spark.driver.extraJavaOptions", " -Dslices=5 -Dcolor=mostly_red") + conf.set("spark.driver.extraClassPath", "food-coloring.jar") + conf.set("spark.driver.extraLibraryPath", "pickle.jar") + conf.set("spark.driver.supervise", "false") + conf.set("spark.executor.memory", "256m") + conf.set("spark.cores.max", "10000") + message.sparkProperties = conf.getAll.toMap + message.appArgs = Array("two slices", "a hint of cinnamon") + message.environmentVariables = Map("PATH" -> "/dev/null") + message.validate() + // bad fields + var badConf = conf.clone().set("spark.driver.cores", "one hundred feet") + message.sparkProperties = badConf.getAll.toMap + intercept[SubmitRestProtocolException] { message.validate() } + badConf = conf.clone().set("spark.driver.supervise", "nope, never") + message.sparkProperties = badConf.getAll.toMap + intercept[SubmitRestProtocolException] { message.validate() } + badConf = conf.clone().set("spark.cores.max", "two men") + message.sparkProperties = badConf.getAll.toMap + intercept[SubmitRestProtocolException] { message.validate() } + message.sparkProperties = conf.getAll.toMap + // test JSON + val json = message.toJson + assertJsonEquals(json, submitDriverRequestJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionRequest]) + assert(newMessage.clientSparkVersion === "1.2.3") + assert(newMessage.appResource === "honey-walnut-cherry.jar") + assert(newMessage.mainClass === "org.apache.spark.examples.SparkPie") + assert(newMessage.sparkProperties("spark.app.name") === "SparkPie") + assert(newMessage.sparkProperties("spark.jars") === "mayonnaise.jar,ketchup.jar") + assert(newMessage.sparkProperties("spark.files") === "fireball.png") + assert(newMessage.sparkProperties("spark.driver.memory") === "512m") + assert(newMessage.sparkProperties("spark.driver.cores") === "180") + assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === " -Dslices=5 -Dcolor=mostly_red") + assert(newMessage.sparkProperties("spark.driver.extraClassPath") === "food-coloring.jar") + assert(newMessage.sparkProperties("spark.driver.extraLibraryPath") === "pickle.jar") + assert(newMessage.sparkProperties("spark.driver.supervise") === "false") + assert(newMessage.sparkProperties("spark.executor.memory") === "256m") + assert(newMessage.sparkProperties("spark.cores.max") === "10000") + assert(newMessage.appArgs === message.appArgs) + assert(newMessage.sparkProperties === message.sparkProperties) + assert(newMessage.environmentVariables === message.environmentVariables) + } + + test("CreateSubmissionResponse") { + val message = new CreateSubmissionResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.submissionId = "driver_123" + message.success = true + message.validate() + // test JSON + val json = message.toJson + assertJsonEquals(json, submitDriverResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.submissionId === "driver_123") + assert(newMessage.success) + } + + test("KillSubmissionResponse") { + val message = new KillSubmissionResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.submissionId = "driver_123" + message.success = true + message.validate() + // test JSON + val json = message.toJson + assertJsonEquals(json, killDriverResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillSubmissionResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.submissionId === "driver_123") + assert(newMessage.success) + } + + test("SubmissionStatusResponse") { + val message = new SubmissionStatusResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.submissionId = "driver_123" + message.success = true + message.validate() + // optional fields + message.driverState = "RUNNING" + message.workerId = "worker_123" + message.workerHostPort = "1.2.3.4:7780" + // test JSON + val json = message.toJson + assertJsonEquals(json, driverStatusResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmissionStatusResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.submissionId === "driver_123") + assert(newMessage.driverState === "RUNNING") + assert(newMessage.success) + assert(newMessage.workerId === "worker_123") + assert(newMessage.workerHostPort === "1.2.3.4:7780") + } + + test("ErrorResponse") { + val message = new ErrorResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.message = "Field not found in submit request: X" + message.validate() + // test JSON + val json = message.toJson + assertJsonEquals(json, errorJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[ErrorResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.message === "Field not found in submit request: X") + } + + private val dummyRequestJson = + """ + |{ + | "action" : "DummyRequest", + | "active" : true, + | "age" : 25, + | "clientSparkVersion" : "1.2.3", + | "name" : "jung" + |} + """.stripMargin + + private val dummyResponseJson = + """ + |{ + | "action" : "DummyResponse", + | "serverSparkVersion" : "3.3.4", + | "success": true + |} + """.stripMargin + + private val submitDriverRequestJson = + """ + |{ + | "action" : "CreateSubmissionRequest", + | "appArgs" : [ "two slices", "a hint of cinnamon" ], + | "appResource" : "honey-walnut-cherry.jar", + | "clientSparkVersion" : "1.2.3", + | "environmentVariables" : { + | "PATH" : "/dev/null" + | }, + | "mainClass" : "org.apache.spark.examples.SparkPie", + | "sparkProperties" : { + | "spark.driver.extraLibraryPath" : "pickle.jar", + | "spark.jars" : "mayonnaise.jar,ketchup.jar", + | "spark.driver.supervise" : "false", + | "spark.app.name" : "SparkPie", + | "spark.cores.max" : "10000", + | "spark.driver.memory" : "512m", + | "spark.files" : "fireball.png", + | "spark.driver.cores" : "180", + | "spark.driver.extraJavaOptions" : " -Dslices=5 -Dcolor=mostly_red", + | "spark.executor.memory" : "256m", + | "spark.driver.extraClassPath" : "food-coloring.jar" + | } + |} + """.stripMargin + + private val submitDriverResponseJson = + """ + |{ + | "action" : "CreateSubmissionResponse", + | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", + | "success" : true + |} + """.stripMargin + + private val killDriverResponseJson = + """ + |{ + | "action" : "KillSubmissionResponse", + | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", + | "success" : true + |} + """.stripMargin + + private val driverStatusResponseJson = + """ + |{ + | "action" : "SubmissionStatusResponse", + | "driverState" : "RUNNING", + | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", + | "success" : true, + | "workerHostPort" : "1.2.3.4:7780", + | "workerId" : "worker_123" + |} + """.stripMargin + + private val errorJson = + """ + |{ + | "action" : "ErrorResponse", + | "message" : "Field not found in submit request: X", + | "serverSparkVersion" : "1.2.3" + |} + """.stripMargin + + /** Assert that the contents in the two JSON strings are equal after ignoring whitespace. */ + private def assertJsonEquals(jsonString1: String, jsonString2: String): Unit = { + val trimmedJson1 = jsonString1.trim + val trimmedJson2 = jsonString2.trim + val json1 = compact(render(parse(trimmedJson1))) + val json2 = compact(render(parse(trimmedJson2))) + // Put this on a separate line to avoid printing comparison twice when test fails + val equals = json1 == json2 + assert(equals, "\"[%s]\" did not equal \"[%s]\"".format(trimmedJson1, trimmedJson2)) + } +} + +private class DummyResponse extends SubmitRestProtocolResponse +private class DummyRequest extends SubmitRestProtocolRequest { + var active: Boolean = null + var age: Integer = null + var name: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(name, "name") + assertFieldIsSet(age, "age") + assert(age > 5, "Not old enough!") + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index b6f4411e0587a..aa6e4874cecde 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -27,6 +27,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkConf import org.apache.spark.deploy.{Command, DriverDescription} +import org.apache.spark.util.Clock class DriverRunnerTest extends FunSuite { private def createDriverRunner() = { @@ -129,7 +130,7 @@ class DriverRunnerTest extends FunSuite { .thenReturn(-1) // fail 3 .thenReturn(-1) // fail 4 .thenReturn(0) // success - when(clock.currentTimeMillis()) + when(clock.getTimeMillis()) .thenReturn(0).thenReturn(1000) // fail 1 (short) .thenReturn(1000).thenReturn(2000) // fail 2 (short) .thenReturn(2000).thenReturn(10000) // fail 3 (long) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 6f233d7cf97aa..6fca6321e5a1b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -32,8 +32,8 @@ class ExecutorRunnerTest extends FunSuite { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl") - val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", - new File(sparkHome), new File("ooga"), "blah", new SparkConf, Seq("localDir"), + val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", 123, + "publicAddr", new File(sparkHome), new File("ooga"), "blah", new SparkConf, Seq("localDir"), ExecutorState.RUNNING) val builder = CommandUtils.buildProcessBuilder(appDesc.command, 512, sparkHome, er.substituteVariables) assert(builder.command().last === appId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala similarity index 67% rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala rename to core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 86bb67ec74256..326e203afe136 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -15,16 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.executor -import org.apache.spark.sql.DataFrame +import org.scalatest.FunSuite -private[jdbc] class JavaJDBCTrampoline { - def createJDBCTable(rdd: DataFrame, url: String, table: String, allowExisting: Boolean) { - rdd.createJDBCTable(url, table, allowExisting); - } - - def insertIntoJDBC(rdd: DataFrame, url: String, table: String, overwrite: Boolean) { - rdd.insertIntoJDBC(url, table, overwrite); +class TaskMetricsSuite extends FunSuite { + test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") { + val taskMetrics = new TaskMetrics() + taskMetrics.updateShuffleReadMetrics() + assert(taskMetrics.shuffleReadMetrics.isEmpty) } } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 98b0a16ce88ba..2e58c159a2ed8 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.FunSuite import org.apache.hadoop.io.Text -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -42,7 +42,15 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { private var factory: CompressionCodecFactory = _ override def beforeAll() { - sc = new SparkContext("local", "test") + // Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which + // can cause Filesystem.get(Configuration) to return a cached instance created with a different + // configuration than the one passed to get() (see HADOOP-8490 for more details). This caused + // hard-to-reproduce test failures, since any suites that were run after this one would inherit + // the new value of "fs.local.block.size" (see SPARK-5227 and SPARK-5679). To work around this, + // we disable FileSystem caching in this suite. + val conf = new SparkConf().set("spark.hadoop.fs.file.impl.disable.cache", "true") + + sc = new SparkContext("local", "test", conf) // Set the block size of local file system to test whether files are split right or not. sc.hadoopConfiguration.setLong("fs.local.block.size", 32) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 81db66ae17464..78fa98a3b9065 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -21,44 +21,46 @@ import java.io.{File, FileWriter, PrintWriter} import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.commons.lang.math.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, JobConf, - LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, Reporter, - TextInputFormat => OldTextInputFormat} import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, - CombineFileSplit => OldCombineFileSplit, CombineFileRecordReader => OldCombineFileRecordReader} -import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader, - TaskAttemptContext} + CombineFileRecordReader => OldCombineFileRecordReader, CombineFileSplit => OldCombineFileSplit} +import org.apache.hadoop.mapred.{JobConf, Reporter, FileSplit => OldFileSplit, + InputSplit => OldInputSplit, LineRecordReader => OldLineRecordReader, + RecordReader => OldRecordReader, TextInputFormat => OldTextInputFormat} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} +import org.apache.hadoop.mapreduce.{TaskAttemptContext, InputSplit => NewInputSplit, + RecordReader => NewRecordReader} +import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.SharedSparkContext import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.util.Utils -class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { +class InputOutputMetricsSuite extends FunSuite with SharedSparkContext + with BeforeAndAfter { @transient var tmpDir: File = _ @transient var tmpFile: File = _ @transient var tmpFilePath: String = _ + @transient val numRecords: Int = 100000 + @transient val numBuckets: Int = 10 - override def beforeAll() { - super.beforeAll() - + before { tmpDir = Utils.createTempDir() val testTempDir = new File(tmpDir, "test") testTempDir.mkdir() tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) - for (x <- 1 to 1000000) { - pw.println("s") + for (x <- 1 to numRecords) { + pw.println(RandomUtils.nextInt(numBuckets)) } pw.close() @@ -66,8 +68,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { tmpFilePath = "file://" + tmpFile.getAbsolutePath } - override def afterAll() { - super.afterAll() + after { Utils.deleteRecursively(tmpDir) } @@ -155,6 +156,101 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { assert(bytesRead >= tmpFile.length()) } + test("input metrics on records read - simple") { + val records = runAndReturnRecordsRead { + sc.textFile(tmpFilePath, 4).count() + } + assert(records == numRecords) + } + + test("input metrics on records read - more stages") { + val records = runAndReturnRecordsRead { + sc.textFile(tmpFilePath, 4) + .map(key => (key.length, 1)) + .reduceByKey(_ + _) + .count() + } + assert(records == numRecords) + } + + test("input metrics on records - New Hadoop API") { + val records = runAndReturnRecordsRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable], + classOf[Text]).count() + } + assert(records == numRecords) + } + + test("input metrics on recordsd read with cache") { + // prime the cache manager + val rdd = sc.textFile(tmpFilePath, 4).cache() + rdd.collect() + + val records = runAndReturnRecordsRead { + rdd.count() + } + + assert(records == numRecords) + } + + test("shuffle records read metrics") { + val recordsRead = runAndReturnShuffleRecordsRead { + sc.textFile(tmpFilePath, 4) + .map(key => (key, 1)) + .groupByKey() + .collect() + } + assert(recordsRead == numRecords) + } + + test("shuffle records written metrics") { + val recordsWritten = runAndReturnShuffleRecordsWritten { + sc.textFile(tmpFilePath, 4) + .map(key => (key, 1)) + .groupByKey() + .collect() + } + assert(recordsWritten == numRecords) + } + + /** + * Tests the metrics from end to end. + * 1) reading a hadoop file + * 2) shuffle and writing to a hadoop file. + * 3) writing to hadoop file. + */ + test("input read/write and shuffle read/write metrics all line up") { + var inputRead = 0L + var outputWritten = 0L + var shuffleRead = 0L + var shuffleWritten = 0L + sc.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val metrics = taskEnd.taskMetrics + metrics.inputMetrics.foreach(inputRead += _.recordsRead) + metrics.outputMetrics.foreach(outputWritten += _.recordsWritten) + metrics.shuffleReadMetrics.foreach(shuffleRead += _.recordsRead) + metrics.shuffleWriteMetrics.foreach(shuffleWritten += _.shuffleRecordsWritten) + } + }) + + val tmpFile = new File(tmpDir, getClass.getSimpleName) + + sc.textFile(tmpFilePath, 4) + .map(key => (key, 1)) + .reduceByKey(_+_) + .saveAsTextFile("file://" + tmpFile.getAbsolutePath) + + sc.listenerBus.waitUntilEmpty(500) + assert(inputRead == numRecords) + + // Only supported on newer Hadoop + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { + assert(outputWritten == numBuckets) + } + assert(shuffleRead == shuffleWritten) + } + test("input metrics with interleaved reads") { val numPartitions = 2 val cartVector = 0 to 9 @@ -193,18 +289,66 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { assert(cartesianBytes == firstSize * numPartitions + (cartVector.length * secondSize)) } - private def runAndReturnBytesRead(job : => Unit): Long = { - val taskBytesRead = new ArrayBuffer[Long]() + private def runAndReturnBytesRead(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.inputMetrics.map(_.bytesRead)) + } + + private def runAndReturnRecordsRead(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.inputMetrics.map(_.recordsRead)) + } + + private def runAndReturnRecordsWritten(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten)) + } + + private def runAndReturnShuffleRecordsRead(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.shuffleReadMetrics.map(_.recordsRead)) + } + + private def runAndReturnShuffleRecordsWritten(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten)) + } + + private def runAndReturnMetrics(job: => Unit, + collector: (SparkListenerTaskEnd) => Option[Long]): Long = { + val taskMetrics = new ArrayBuffer[Long]() sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead + collector(taskEnd).foreach(taskMetrics += _) } }) job sc.listenerBus.waitUntilEmpty(500) - taskBytesRead.sum + taskMetrics.sum + } + + test("output metrics on records written") { + // Only supported on newer Hadoop + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { + val file = new File(tmpDir, getClass.getSimpleName) + val filePath = "file://" + file.getAbsolutePath + + val records = runAndReturnRecordsWritten { + sc.parallelize(1 to numRecords).saveAsTextFile(filePath) + } + assert(records == numRecords) + } + } + + test("output metrics on records written - new Hadoop API") { + // Only supported on newer Hadoop + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { + val file = new File(tmpDir, getClass.getSimpleName) + val filePath = "file://" + file.getAbsolutePath + + val records = runAndReturnRecordsWritten { + sc.parallelize(1 to numRecords).map(key => (key.toString, key.toString)) + .saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](filePath) + } + assert(records == numRecords) + } } test("output metrics when writing text file") { @@ -318,4 +462,4 @@ class NewCombineTextRecordReaderWrapper( override def getCurrentValue(): Text = delegate.getCurrentValue override def getProgress(): Float = delegate.getProgress override def close(): Unit = delegate.close() -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 1a9ce8c607dcd..37e528435aa5d 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -27,7 +27,7 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { } test("MetricsConfig with default properties") { - val conf = new MetricsConfig(Option("dummy-file")) + val conf = new MetricsConfig(None) conf.initialize() assert(conf.properties.size() === 4) diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index de306533752c1..787b06e1e6701 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -232,6 +232,12 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithDoubleValuesAtMinMax") { + val rdd = sc.parallelize(Seq(1, 1, 1, 2, 3, 3)) + assert(Array(3, 0, 1, 2) === rdd.map(_.toDouble).histogram(4)._2) + assert(Array(3, 1, 2) === rdd.map(_.toDouble).histogram(3)._2) + } + test("WorksWithoutBucketsWithMoreRequestedThanElements") { // Verify the basic case of one bucket and all elements in that bucket works val rdd = sc.parallelize(Seq(1, 2)) @@ -245,7 +251,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { } test("WorksWithoutBucketsForLargerDatasets") { - // Verify the case of slighly larger datasets + // Verify the case of slightly larger datasets val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(8) val expectedHistogramResults = @@ -256,17 +262,27 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } - test("WorksWithoutBucketsWithIrrationalBucketEdges") { - // Verify the case of buckets with irrational edges. See #SPARK-2862. + test("WorksWithoutBucketsWithNonIntegralBucketEdges") { + // Verify the case of buckets with nonintegral edges. See #SPARK-2862. val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(9) + // Buckets are 6.0, 16.333333333333336, 26.666666666666668, 37.0, 47.333333333333336 ... val expectedHistogramResults = - Array(11, 10, 11, 10, 10, 11, 10, 10, 11) + Array(11, 10, 10, 11, 10, 10, 11, 10, 11) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets(0) === 6.0) assert(histogramBuckets(9) === 99.0) } + test("WorksWithHugeRange") { + val rdd = sc.parallelize(Array(0, 1.0e24, 1.0e30)) + val histogramResults = rdd.histogram(1000000)._2 + assert(histogramResults(0) === 1) + assert(histogramResults(1) === 1) + assert(histogramResults.last === 1) + assert((2 to histogramResults.length - 2).forall(i => histogramResults(i) == 0)) + } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index eb116213f69fc..febfe081b28f2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -208,7 +208,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) } } } @@ -219,7 +219,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, - Map[Long, Any]((accumId, 1)), null, null)) + Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null)) } } } @@ -476,7 +476,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null, Map[Long, Any](), - null, + createFakeTaskInfo(), null)) assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(1)) @@ -487,7 +487,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), null, Map[Long, Any](), - null, + createFakeTaskInfo(), null)) // The SparkListener should not receive redundant failure events. assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) @@ -507,14 +507,14 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(newEpoch > oldEpoch) val taskSet = taskSets(0) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a new epoch taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -765,6 +765,16 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(scheduler.runningStages.isEmpty) assert(scheduler.shuffleToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) + assert(scheduler.outputCommitCoordinator.isEmpty) } + + // Nothing in this test should break if the task info's fields are null, but + // OutputCommitCoordinator requires the task info itself to not be null. + private def createFakeTaskInfo(): TaskInfo = { + val info = new TaskInfo(0, 0, 0, 0L, "", "", TaskLocality.ANY, false) + info.finishTime = 1 // to prevent spurious errors in JobProgressListener + info + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 437d8693c0b1f..7e401664a6337 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.io.{File, FileOutputStream, InputStream, IOException} +import java.net.URI import scala.collection.mutable import scala.io.Source @@ -26,7 +27,7 @@ import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext, SPARK_VERSION} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io._ import org.apache.spark.util.{JsonProtocol, Utils} @@ -59,7 +60,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin test("Verify log file exist") { // Verify logging directory exists val conf = getLoggingConf(testDirPath) - val eventLogger = new EventLoggingListener("test", testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener("test", testDirPath.toUri(), conf) eventLogger.start() val logPath = new Path(eventLogger.logPath + EventLoggingListener.IN_PROGRESS) @@ -78,7 +79,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin test("Basic event logging with compression") { CompressionCodec.ALL_COMPRESSION_CODECS.foreach { codec => - testEventLogging(compressionCodec = Some(codec)) + testEventLogging(compressionCodec = Some(CompressionCodec.getShortName(codec))) } } @@ -88,25 +89,38 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin test("End-to-end event logging with compression") { CompressionCodec.ALL_COMPRESSION_CODECS.foreach { codec => - testApplicationEventLogging(compressionCodec = Some(codec)) + testApplicationEventLogging(compressionCodec = Some(CompressionCodec.getShortName(codec))) } } test("Log overwriting") { - val log = new FileOutputStream(new File(testDir, "test")) - log.close() - try { - testEventLogging() - assert(false) - } catch { - case e: IOException => - // Expected, since we haven't enabled log overwrite. - } - + val logUri = EventLoggingListener.getLogPath(testDir.toURI, "test") + val logPath = new URI(logUri).getPath + // Create file before writing the event log + new FileOutputStream(new File(logPath)).close() + // Expected IOException, since we haven't enabled log overwrite. + intercept[IOException] { testEventLogging() } // Try again, but enable overwriting. testEventLogging(extraConf = Map("spark.eventLog.overwrite" -> "true")) } + test("Event log name") { + // without compression + assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath( + Utils.resolveURI("/base-dir"), "app1")) + // with compression + assert(s"file:/base-dir/app1.lzf" === + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), "app1", Some("lzf"))) + // illegal characters in app ID + assert(s"file:/base-dir/a-fine-mind_dollar_bills__1" === + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1")) + // illegal characters in app ID with compression + assert(s"file:/base-dir/a-fine-mind_dollar_bills__1.lz4" === + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1", Some("lz4"))) + } + /* ----------------- * * Actual test logic * * ----------------- */ @@ -125,7 +139,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin val conf = getLoggingConf(testDirPath, compressionCodec) extraConf.foreach { case (k, v) => conf.set(k, v) } val logName = compressionCodec.map("test-" + _).getOrElse("test") - val eventLogger = new EventLoggingListener(logName, testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener(logName, testDirPath.toUri(), conf) val listenerBus = new LiveListenerBus val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey") @@ -140,15 +154,17 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin eventLogger.stop() // Verify file contains exactly the two events logged - val (logData, version) = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), - fileSystem) + val logData = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), fileSystem) try { val lines = readLines(logData) - assert(lines.size === 2) - assert(lines(0).contains("SparkListenerApplicationStart")) - assert(lines(1).contains("SparkListenerApplicationEnd")) - assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === applicationStart) - assert(JsonProtocol.sparkEventFromJson(parse(lines(1))) === applicationEnd) + val logStart = SparkListenerLogStart(SPARK_VERSION) + assert(lines.size === 3) + assert(lines(0).contains("SparkListenerLogStart")) + assert(lines(1).contains("SparkListenerApplicationStart")) + assert(lines(2).contains("SparkListenerApplicationEnd")) + assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart) + assert(JsonProtocol.sparkEventFromJson(parse(lines(1))) === applicationStart) + assert(JsonProtocol.sparkEventFromJson(parse(lines(2))) === applicationEnd) } finally { logData.close() } @@ -159,12 +175,17 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin * This runs a simple Spark job and asserts that the expected events are logged when expected. */ private def testApplicationEventLogging(compressionCodec: Option[String] = None) { + // Set defaultFS to something that would cause an exception, to make sure we don't run + // into SPARK-6688. val conf = getLoggingConf(testDirPath, compressionCodec) + .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get - val expectedLogDir = testDir.toURI().toString() - assert(eventLogger.logPath.startsWith(expectedLogDir + "/")) + val eventLogPath = eventLogger.logPath + val expectedLogDir = testDir.toURI() + assert(eventLogPath === EventLoggingListener.getLogPath( + expectedLogDir, sc.applicationId, compressionCodec.map(CompressionCodec.getShortName))) // Begin listening for events that trigger asserts val eventExistenceListener = new EventExistenceListener(eventLogger) @@ -178,8 +199,8 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin eventExistenceListener.assertAllCallbacksInvoked() // Make sure expected events exist in the log file. - val (logData, version) = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), - fileSystem) + val logData = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), fileSystem) + val logStart = SparkListenerLogStart(SPARK_VERSION) val lines = readLines(logData) val eventSet = mutable.Set( SparkListenerApplicationStart, @@ -204,6 +225,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin } } } + assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart) assert(eventSet.isEmpty, "The following events are missing: " + eventSet.toSeq) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala new file mode 100644 index 0000000000000..732d466848558 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.File +import java.util.concurrent.TimeoutException + +import org.mockito.Matchers +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter} + +import org.apache.spark._ +import org.apache.spark.rdd.{RDD, FakeOutputCommitter} +import org.apache.spark.util.Utils + +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.postfixOps + +/** + * Unit tests for the output commit coordination functionality. + * + * The unit test makes both the original task and the speculated task + * attempt to commit, where committing is emulated by creating a + * directory. If both tasks create directories then the end result is + * a failure. + * + * Note that there are some aspects of this test that are less than ideal. + * In particular, the test mocks the speculation-dequeuing logic to always + * dequeue a task and consider it as speculated. Immediately after initially + * submitting the tasks and calling reviveOffers(), reviveOffers() is invoked + * again to pick up the speculated task. This may be hacking the original + * behavior in too much of an unrealistic fashion. + * + * Also, the validation is done by checking the number of files in a directory. + * Ideally, an accumulator would be used for this, where we could increment + * the accumulator in the output committer's commitTask() call. If the call to + * commitTask() was called twice erroneously then the test would ideally fail because + * the accumulator would be incremented twice. + * + * The problem with this test implementation is that when both a speculated task and + * its original counterpart complete, only one of the accumulator's increments is + * captured. This results in a paradox where if the OutputCommitCoordinator logic + * was not in SparkHadoopWriter, the tests would still pass because only one of the + * increments would be captured even though the commit in both tasks was executed + * erroneously. + */ +class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter { + + var outputCommitCoordinator: OutputCommitCoordinator = null + var tempDir: File = null + var sc: SparkContext = null + + before { + tempDir = Utils.createTempDir() + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName(classOf[OutputCommitCoordinatorSuite].getSimpleName) + .set("spark.speculation", "true") + sc = new SparkContext(conf) { + override private[spark] def createSparkEnv( + conf: SparkConf, + isLocal: Boolean, + listenerBus: LiveListenerBus): SparkEnv = { + outputCommitCoordinator = spy(new OutputCommitCoordinator(conf)) + // Use Mockito.spy() to maintain the default infrastructure everywhere else. + // This mocking allows us to control the coordinator responses in test cases. + SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator)) + } + } + // Use Mockito.spy() to maintain the default infrastructure everywhere else + val mockTaskScheduler = spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]) + + doAnswer(new Answer[Unit]() { + override def answer(invoke: InvocationOnMock): Unit = { + // Submit the tasks, then force the task scheduler to dequeue the + // speculated task + invoke.callRealMethod() + mockTaskScheduler.backend.reviveOffers() + } + }).when(mockTaskScheduler).submitTasks(Matchers.any()) + + doAnswer(new Answer[TaskSetManager]() { + override def answer(invoke: InvocationOnMock): TaskSetManager = { + val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet] + new TaskSetManager(mockTaskScheduler, taskSet, 4) { + var hasDequeuedSpeculatedTask = false + override def dequeueSpeculativeTask( + execId: String, + host: String, + locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = { + if (!hasDequeuedSpeculatedTask) { + hasDequeuedSpeculatedTask = true + Some(0, TaskLocality.PROCESS_LOCAL) + } else { + None + } + } + } + } + }).when(mockTaskScheduler).createTaskSetManager(Matchers.any(), Matchers.any()) + + sc.taskScheduler = mockTaskScheduler + val dagSchedulerWithMockTaskScheduler = new DAGScheduler(sc, mockTaskScheduler) + sc.taskScheduler.setDAGScheduler(dagSchedulerWithMockTaskScheduler) + sc.dagScheduler = dagSchedulerWithMockTaskScheduler + } + + after { + sc.stop() + tempDir.delete() + outputCommitCoordinator = null + } + + test("Only one of two duplicate commit tasks should commit") { + val rdd = sc.parallelize(Seq(1), 1) + sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _, + 0 until rdd.partitions.size, allowLocal = false) + assert(tempDir.list().size === 1) + } + + test("If commit fails, if task is retried it should not be locked, and will succeed.") { + val rdd = sc.parallelize(Seq(1), 1) + sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _, + 0 until rdd.partitions.size, allowLocal = false) + assert(tempDir.list().size === 1) + } + + test("Job should not complete if all commits are denied") { + // Create a mock OutputCommitCoordinator that denies all attempts to commit + doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit( + Matchers.any(), Matchers.any(), Matchers.any()) + val rdd: RDD[Int] = sc.parallelize(Seq(1), 1) + def resultHandler(x: Int, y: Unit): Unit = {} + val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd, + OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully, + 0 until rdd.partitions.size, resultHandler, 0) + // It's an error if the job completes successfully even though no committer was authorized, + // so throw an exception if the job was allowed to complete. + intercept[TimeoutException] { + Await.result(futureAction, 5 seconds) + } + assert(tempDir.list().size === 0) + } + + test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { + val stage: Int = 1 + val partition: Long = 2 + val authorizedCommitter: Long = 3 + val nonAuthorizedCommitter: Long = 100 + outputCommitCoordinator.stageStart(stage) + assert(outputCommitCoordinator.canCommit(stage, partition, attempt = authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter)) + // The non-authorized committer fails + outputCommitCoordinator.taskCompleted( + stage, partition, attempt = nonAuthorizedCommitter, reason = TaskKilled) + // New tasks should still not be able to commit because the authorized committer has not failed + assert( + !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 1)) + // The authorized committer now fails, clearing the lock + outputCommitCoordinator.taskCompleted( + stage, partition, attempt = authorizedCommitter, reason = TaskKilled) + // A new task should now be allowed to become the authorized committer + assert( + outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 2)) + // There can only be one authorized committer + assert( + !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 3)) + } +} + +/** + * Class with methods that can be passed to runJob to test commits with a mock committer. + */ +private case class OutputCommitFunctions(tempDirPath: String) { + + // Mock output committer that simulates a successful commit (after commit is authorized) + private def successfulOutputCommitter = new FakeOutputCommitter { + override def commitTask(context: TaskAttemptContext): Unit = { + Utils.createDirectory(tempDirPath) + } + } + + // Mock output committer that simulates a failed commit (after commit is authorized) + private def failingOutputCommitter = new FakeOutputCommitter { + override def commitTask(taskAttemptContext: TaskAttemptContext) { + throw new RuntimeException + } + } + + def commitSuccessfully(iter: Iterator[Int]): Unit = { + val ctx = TaskContext.get() + runCommitWithProvidedCommitter(ctx, iter, successfulOutputCommitter) + } + + def failFirstCommitAttempt(iter: Iterator[Int]): Unit = { + val ctx = TaskContext.get() + runCommitWithProvidedCommitter(ctx, iter, + if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter) + } + + private def runCommitWithProvidedCommitter( + ctx: TaskContext, + iter: Iterator[Int], + outputCommitter: OutputCommitter): Unit = { + def jobConf = new JobConf { + override def getOutputCommitter(): OutputCommitter = outputCommitter + } + val sparkHadoopWriter = new SparkHadoopWriter(jobConf) { + override def newTaskAttemptContext( + conf: JobConf, + attemptId: TaskAttemptID): TaskAttemptContext = { + mock(classOf[TaskAttemptContext]) + } + } + sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber) + sparkHadoopWriter.commit() + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 7e360cc6082ec..6de6d2fec622a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.io.{File, PrintWriter} +import java.net.URI import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -61,7 +62,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { try { val replayer = new ReplayListenerBus() replayer.addListener(eventMonster) - replayer.replay(logData, SPARK_VERSION) + replayer.replay(logData, logFilePath.toString) } finally { logData.close() } @@ -115,12 +116,12 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { assert(!eventLog.isDir) // Replay events - val (logData, version) = EventLoggingListener.openEventLog(eventLog.getPath(), fileSystem) + val logData = EventLoggingListener.openEventLog(eventLog.getPath(), fileSystem) val eventMonster = new EventMonster(conf) try { val replayer = new ReplayListenerBus() replayer.addListener(eventMonster) - replayer.replay(logData, version) + replayer.replay(logData, eventLog.getPath().toString) } finally { logData.close() } @@ -145,16 +146,9 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { * log the events. */ private class EventMonster(conf: SparkConf) - extends EventLoggingListener("test", "testdir", conf) { + extends EventLoggingListener("test", new URI("testdir"), conf) { override def start() { } } - - private def getCompressionCodec(codecName: String) = { - val conf = new SparkConf - conf.set("spark.io.compression.codec", codecName) - CompressionCodec.createCodec(conf) - } - } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 0fb1bdd30d975..3a41ee8d4ae0c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -20,26 +20,22 @@ package org.apache.spark.scheduler import java.util.concurrent.Semaphore import scala.collection.mutable +import scala.collection.JavaConversions._ -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} -import org.scalatest.Matchers +import org.scalatest.{FunSuite, Matchers} -import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} -class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers with BeforeAndAfter - with BeforeAndAfterAll with ResetSystemProperties { +class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers + with ResetSystemProperties { /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 val jobCompletionTime = 1421191296660L - before { - sc = new SparkContext("local", "SparkListenerSuite") - } - test("basic creation and shutdown of LiveListenerBus") { val counter = new BasicJobCounter val bus = new LiveListenerBus @@ -127,6 +123,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("basic creation of StageInfo") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -148,6 +145,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("basic creation of StageInfo with shuffle") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -185,6 +183,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("StageInfo with fewer tasks than partitions") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -201,6 +200,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("local metrics") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) @@ -267,6 +267,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("onTaskGettingResult() called when result fetched remotely") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -287,6 +288,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("onTaskGettingResult() not called when result sent directly") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -302,6 +304,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("onTaskEnd() should be called for all started tasks, even after job has been killed") { + sc = new SparkContext("local", "SparkListenerSuite") val WAIT_TIMEOUT_MILLIS = 10000 val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -356,6 +359,17 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers assert(jobCounter2.count === 5) } + test("registering listeners via spark.extraListeners") { + val conf = new SparkConf().setMaster("local").setAppName("test") + .set("spark.extraListeners", classOf[ListenerThatAcceptsSparkConf].getName + "," + + classOf[BasicJobCounter].getName) + sc = new SparkContext(conf) + sc.listenerBus.listeners.collect { case x: BasicJobCounter => x}.size should be (1) + sc.listenerBus.listeners.collect { + case x: ListenerThatAcceptsSparkConf => x + }.size should be (1) + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ @@ -363,14 +377,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers assert(m.sum / m.size.toDouble > 0.0, msg) } - /** - * A simple listener that counts the number of jobs observed. - */ - private class BasicJobCounter extends SparkListener { - var count = 0 - override def onJobEnd(job: SparkListenerJobEnd) = count += 1 - } - /** * A simple listener that saves all task infos and task metrics. */ @@ -423,3 +429,19 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } } + +// These classes can't be declared inside of the SparkListenerSuite class because we don't want +// their constructors to contain references to SparkListenerSuite: + +/** + * A simple listener that counts the number of jobs observed. + */ +private class BasicJobCounter extends SparkListener { + var count = 0 + override def onJobEnd(job: SparkListenerJobEnd) = count += 1 +} + +private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { + var count = 0 + override def onJobEnd(job: SparkListenerJobEnd) = count += 1 +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 84b9b788237bf..12330d8f63c40 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.FakeClock +import org.apache.spark.util.ManualClock class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -164,7 +164,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Offer a host with NO_PREF as the constraint, @@ -213,7 +213,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execC", "host2")) val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "execB"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // An executor that is not NODE_LOCAL should be rejected. @@ -234,7 +234,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), Seq() // Last task has no locality prefs ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -263,7 +263,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host2", "exec3")), Seq() // Last task has no locality prefs ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).get.index === 0) @@ -283,7 +283,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host3")), Seq(TaskLocation("host2")) ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen @@ -314,13 +314,14 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("delay scheduling with failed hosts") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), + ("exec3", "host3")) val taskSet = FakeTask.createTaskSet(3, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), Seq(TaskLocation("host3")) ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen @@ -352,7 +353,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -369,7 +370,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted @@ -401,7 +402,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { ("exec1.1", "host1"), ("exec2", "host2")) // affinity to exec1 on host1 - which we will fail. val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, 4, clock) { @@ -485,7 +486,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host1", "execB")), Seq(TaskLocation("host2", "execC")), Seq()) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Only ANY is valid assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) @@ -521,7 +522,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val taskSet = FakeTask.createTaskSet(2, Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host1", "execA"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) @@ -610,7 +611,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host2"), TaskLocation("host1")), Seq(), Seq(TaskLocation("host3", "execC"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 0) @@ -636,7 +637,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host2")), Seq(), Seq(TaskLocation("host3"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // node-local tasks are scheduled without delay @@ -649,6 +650,47 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) } + test("SPARK-4939: node-local tasks should be scheduled right after process-local tasks finished") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(ExecutorCacheTaskLocation("host1", "execA")), + Seq(ExecutorCacheTaskLocation("host2", "execB"))) + val clock = new ManualClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + // process-local tasks are scheduled first + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 2) + assert(manager.resourceOffer("execB", "host2", NODE_LOCAL).get.index === 3) + // node-local tasks are scheduled without delay + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 0) + assert(manager.resourceOffer("execB", "host2", NODE_LOCAL).get.index === 1) + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None) + assert(manager.resourceOffer("execB", "host2", NODE_LOCAL) == None) + } + + test("SPARK-4939: no-pref tasks should be scheduled after process-local tasks finished") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + val taskSet = FakeTask.createTaskSet(3, + Seq(), + Seq(ExecutorCacheTaskLocation("host1", "execA")), + Seq(ExecutorCacheTaskLocation("host2", "execB"))) + val clock = new ManualClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + // process-local tasks are scheduled first + assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 1) + assert(manager.resourceOffer("execB", "host2", PROCESS_LOCAL).get.index === 2) + // no-pref tasks are scheduled without delay + assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL) == None) + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None) + assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index === 0) + assert(manager.resourceOffer("execA", "host1", ANY) == None) + } + test("Ensure TaskSetManager is usable after addition of levels") { // Regression test for SPARK-2931 sc = new SparkContext("local", "test") @@ -656,7 +698,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val taskSet = FakeTask.createTaskSet(2, Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host2", "execB.1"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Only ANY is valid assert(manager.myLocalityLevels.sameElements(Array(ANY))) @@ -690,7 +732,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(HostTaskLocation("host1")), Seq(HostTaskLocation("host2")), Seq(HDFSCacheTaskLocation("host3"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) sched.removeExecutor("execA") diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala index f2ff98eb72daf..afbaa9ade811f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -17,45 +17,48 @@ package org.apache.spark.scheduler.mesos -import org.apache.spark.executor.MesosExecutorBackend -import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext} -import org.apache.spark.scheduler.{SparkListenerExecutorAdded, LiveListenerBus, - TaskDescription, WorkerOffer, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.cluster.mesos.{MemoryUtils, MesosSchedulerBackend} -import org.apache.mesos.SchedulerDriver -import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, _} -import org.apache.mesos.Protos.Value.Scalar -import org.easymock.{Capture, EasyMock} import java.nio.ByteBuffer -import java.util.Collections import java.util -import org.scalatest.mock.EasyMockSugar +import java.util.Collections import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with EasyMockSugar { +import org.apache.mesos.SchedulerDriver +import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.Value.Scalar +import org.mockito.Mockito._ +import org.mockito.Matchers._ +import org.mockito.{ArgumentCaptor, Matchers} +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.executor.MesosExecutorBackend +import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, + TaskDescription, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.scheduler.cluster.mesos.{MesosSchedulerBackend, MemoryUtils} + +class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with MockitoSugar { test("check spark-class location correctly") { val conf = new SparkConf conf.set("spark.mesos.executor.home" , "/mesos-home") - val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) - listenerBus.post(SparkListenerExecutorAdded(EasyMock.anyLong, "s1", new ExecutorInfo("host1", 2))) - EasyMock.replay(listenerBus) - - val sc = EasyMock.createMock(classOf[SparkContext]) - EasyMock.expect(sc.getSparkHome()).andReturn(Option("/spark-home")).anyTimes() - EasyMock.expect(sc.conf).andReturn(conf).anyTimes() - EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() - EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() - EasyMock.expect(sc.listenerBus).andReturn(listenerBus) - EasyMock.replay(sc) - val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) - EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() - EasyMock.replay(taskScheduler) + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + + when(sc.conf).thenReturn(conf) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.executorMemory).thenReturn(100) + when(sc.listenerBus).thenReturn(listenerBus) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") @@ -84,20 +87,19 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")).setHostname(s"host${id.toString}").build() } - val driver = EasyMock.createMock(classOf[SchedulerDriver]) - val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] - val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) - listenerBus.post(SparkListenerExecutorAdded(EasyMock.anyLong, "s1", new ExecutorInfo("host1", 2))) - EasyMock.replay(listenerBus) + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) - val sc = EasyMock.createMock(classOf[SparkContext]) - EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() - EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes() - EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() - EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes() - EasyMock.expect(sc.listenerBus).andReturn(listenerBus) - EasyMock.replay(sc) + val sc = mock[SparkContext] + when(sc.executorMemory).thenReturn(100) + when(sc.getSparkHome()).thenReturn(Option("/path")) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.conf).thenReturn(new SparkConf) + when(sc.listenerBus).thenReturn(listenerBus) val minMem = MemoryUtils.calculateTotalMemory(sc).toInt val minCpu = 4 @@ -121,25 +123,29 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea 2 )) val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) - EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(expectedWorkerOffers))).andReturn(Seq(Seq(taskDesc))) - EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() - EasyMock.replay(taskScheduler) + when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) - val capture = new Capture[util.Collection[TaskInfo]] - EasyMock.expect( + val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + when( driver.launchTasks( - EasyMock.eq(Collections.singleton(mesosOffers.get(0).getId)), - EasyMock.capture(capture), - EasyMock.anyObject(classOf[Filters]) + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) ) - ).andReturn(Status.valueOf(1)).once - EasyMock.expect(driver.declineOffer(mesosOffers.get(1).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.expect(driver.declineOffer(mesosOffers.get(2).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.replay(driver) + ).thenReturn(Status.valueOf(1)) + when(driver.declineOffer(mesosOffers.get(1).getId)).thenReturn(Status.valueOf(1)) + when(driver.declineOffer(mesosOffers.get(2).getId)).thenReturn(Status.valueOf(1)) backend.resourceOffers(driver, mesosOffers) - EasyMock.verify(driver) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) + ) + verify(driver, times(1)).declineOffer(mesosOffers.get(1).getId) + verify(driver, times(1)).declineOffer(mesosOffers.get(2).getId) assert(capture.getValue.size() == 1) val taskInfo = capture.getValue.iterator().next() assert(taskInfo.getName.equals("n1")) @@ -151,15 +157,13 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea // Unwanted resources offered on an existing node. Make sure they are declined val mesosOffers2 = new java.util.ArrayList[Offer] mesosOffers2.add(createOffer(1, minMem, minCpu)) - EasyMock.reset(taskScheduler) - EasyMock.reset(driver) - EasyMock.expect(taskScheduler.resourceOffers(EasyMock.anyObject(classOf[Seq[WorkerOffer]])).andReturn(Seq(Seq()))) - EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() - EasyMock.replay(taskScheduler) - EasyMock.expect(driver.declineOffer(mesosOffers2.get(0).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.replay(driver) + reset(taskScheduler) + reset(driver) + when(taskScheduler.resourceOffers(any(classOf[Seq[WorkerOffer]]))).thenReturn(Seq(Seq())) + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) + when(driver.declineOffer(mesosOffers2.get(0).getId)).thenReturn(Status.valueOf(1)) backend.resourceOffers(driver, mesosOffers2) - EasyMock.verify(driver) + verify(driver, times(1)).declineOffer(mesosOffers2.get(0).getId) } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 855f1b6276089..054a4c64897a9 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -29,9 +29,9 @@ class KryoSerializerDistributedSuite extends FunSuite { test("kryo objects are serialised consistently in different processes") { val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) - conf.set("spark.task.maxFailures", "1") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) + .set("spark.task.maxFailures", "1") val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index a70f67af2e62e..6198df84fab3d 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -23,9 +23,10 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, SharedSparkContext} +import org.apache.spark.{SharedSparkContext, SparkConf} +import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ - +import org.apache.spark.storage.BlockManagerId class KryoSerializerSuite extends FunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") @@ -242,6 +243,38 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { ser.newInstance().deserialize[ClassLoaderTestingObject](bytes) } } + + test("registration of HighlyCompressedMapStatus") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + // these cases require knowing the internals of RoaringBitmap a little. Blocks span 2^16 + // values, and they use a bitmap (dense) if they have more than 4096 values, and an + // array (sparse) if they use less. So we just create two cases, one sparse and one dense. + // and we use a roaring bitmap for the empty blocks, so we trigger the dense case w/ mostly + // empty blocks + + val ser = new KryoSerializer(conf).newInstance() + val denseBlockSizes = new Array[Long](5000) + val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) + Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => + ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + } + } + + test("serialization buffer overflow reporting") { + import org.apache.spark.SparkException + val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max.mb" + + val largeObject = (1 to 1000000).toArray + + val conf = new SparkConf(false) + conf.set(kryoBufferMaxProperty, "1") + + val ser = new KryoSerializer(conf).newInstance() + val thrown = intercept[SparkException](ser.serialize(largeObject)) + assert(thrown.getMessage.contains(kryoBufferMaxProperty)) + } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index bbc7e1357b90d..c21c92b63ad13 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -31,6 +31,8 @@ class BlockObjectWriterSuite extends FunSuite { new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) writer.write(Long.box(20)) + // Record metrics update on every write + assert(writeMetrics.shuffleRecordsWritten === 1) // Metrics don't update on every write assert(writeMetrics.shuffleBytesWritten == 0) // After 32 writes, metrics should update @@ -39,6 +41,7 @@ class BlockObjectWriterSuite extends FunSuite { writer.write(Long.box(i)) } assert(writeMetrics.shuffleBytesWritten > 0) + assert(writeMetrics.shuffleRecordsWritten === 33) writer.commitAndClose() assert(file.length() == writeMetrics.shuffleBytesWritten) } @@ -51,6 +54,8 @@ class BlockObjectWriterSuite extends FunSuite { new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) writer.write(Long.box(20)) + // Record metrics update on every write + assert(writeMetrics.shuffleRecordsWritten === 1) // Metrics don't update on every write assert(writeMetrics.shuffleBytesWritten == 0) // After 32 writes, metrics should update @@ -59,7 +64,23 @@ class BlockObjectWriterSuite extends FunSuite { writer.write(Long.box(i)) } assert(writeMetrics.shuffleBytesWritten > 0) + assert(writeMetrics.shuffleRecordsWritten === 33) writer.revertPartialWritesAndClose() assert(writeMetrics.shuffleBytesWritten == 0) + assert(writeMetrics.shuffleRecordsWritten == 0) + } + + test("Reopening a closed block writer") { + val file = new File("somefile") + file.deleteOnExit() + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + + writer.open() + writer.close() + intercept[IllegalStateException] { + writer.open() + } } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index e85a436cdba17..0d155982a8c54 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -17,27 +17,40 @@ package org.apache.spark.ui +import javax.servlet.http.HttpServletRequest + import scala.collection.JavaConversions._ +import scala.xml.Node -import org.openqa.selenium.{By, WebDriver} import org.openqa.selenium.htmlunit.HtmlUnitDriver +import org.openqa.selenium.{By, WebDriver} import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ -import org.apache.spark._ import org.apache.spark.LocalSparkContext._ +import org.apache.spark._ import org.apache.spark.api.java.StorageLevels import org.apache.spark.shuffle.FetchFailedException + /** - * Selenium tests for the Spark Web UI. These tests are not run by default - * because they're slow. + * Selenium tests for the Spark Web UI. */ -@DoNotDiscover -class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { - implicit val webDriver: WebDriver = new HtmlUnitDriver +class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll { + + implicit var webDriver: WebDriver = _ + + override def beforeAll(): Unit = { + webDriver = new HtmlUnitDriver + } + + override def afterAll(): Unit = { + if (webDriver != null) { + webDriver.quit() + } + } /** * Create a test SparkContext with the SparkUI enabled. @@ -48,6 +61,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { .setMaster("local") .setAppName("test") .set("spark.ui.enabled", "true") + .set("spark.ui.port", "0") val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc @@ -93,7 +107,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") - find(id("active")).get.text should be("Active Stages (0)") + find(id("active")) should be(None) // Since we hide empty tables find(id("failed")).get.text should be("Failed Stages (1)") } @@ -105,7 +119,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") - find(id("active")).get.text should be("Active Stages (0)") + find(id("active")) should be(None) // Since we hide empty tables // The failure occurs before the stage becomes active, hence we should still show only one // failed stage, not two: find(id("failed")).get.text should be("Failed Stages (1)") @@ -167,13 +181,14 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { test("job progress bars should handle stage / task failures") { withSpark(newSparkContext()) { sc => - val data = sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity) + val data = sc.parallelize(Seq(1, 2, 3), 1).map(identity).groupBy(identity) val shuffleHandle = data.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle // Simulate fetch failures: val mappedData = data.map { x => val taskContext = TaskContext.get - if (taskContext.attemptNumber == 0) { // Cause this stage to fail on its first attempt. + if (taskContext.taskAttemptId() == 1) { + // Cause the post-shuffle stage to fail on its first attempt with a single task failure val env = SparkEnv.get val bmAddress = env.blockManager.blockManagerId val shuffleId = shuffleHandle.shuffleId @@ -299,4 +314,46 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } } } + + test("attaching and detaching a new tab") { + withSpark(newSparkContext()) { sc => + val sparkUI = sc.ui.get + + val newTab = new WebUITab(sparkUI, "foo") { + attachPage(new WebUIPage("") { + def render(request: HttpServletRequest): Seq[Node] = { + "html magic" + } + }) + } + sparkUI.attachTab(newTab) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/")) + find(cssSelector("""ul li a[href*="jobs"]""")) should not be(None) + find(cssSelector("""ul li a[href*="stages"]""")) should not be(None) + find(cssSelector("""ul li a[href*="storage"]""")) should not be(None) + find(cssSelector("""ul li a[href*="environment"]""")) should not be(None) + find(cssSelector("""ul li a[href*="foo"]""")) should not be(None) + } + eventually(timeout(10 seconds), interval(50 milliseconds)) { + // check whether new page exists + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/foo") + find(cssSelector("b")).get.text should include ("html magic") + } + sparkUI.detachTab(newTab) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/")) + find(cssSelector("""ul li a[href*="jobs"]""")) should not be(None) + find(cssSelector("""ul li a[href*="stages"]""")) should not be(None) + find(cssSelector("""ul li a[href*="storage"]""")) should not be(None) + find(cssSelector("""ul li a[href*="environment"]""")) should not be(None) + find(cssSelector("""ul li a[href*="foo"]""")) should be(None) + } + eventually(timeout(10 seconds), interval(50 milliseconds)) { + // check new page not exist + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/foo") + find(cssSelector("b")) should be(None) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 92a21f82f3c21..77a038dc1720d 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.ui import java.net.ServerSocket -import javax.servlet.http.HttpServletRequest import scala.io.Source import scala.util.{Failure, Success, Try} @@ -28,9 +27,8 @@ import org.scalatest.FunSuite import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkContext, SparkConf} import org.apache.spark.LocalSparkContext._ -import scala.xml.Node +import org.apache.spark.{SparkConf, SparkContext} class UISuite extends FunSuite { @@ -72,40 +70,6 @@ class UISuite extends FunSuite { } } - ignore("attaching a new tab") { - withSpark(newSparkContext()) { sc => - val sparkUI = sc.ui.get - - val newTab = new WebUITab(sparkUI, "foo") { - attachPage(new WebUIPage("") { - def render(request: HttpServletRequest): Seq[Node] = { - "html magic" - } - }) - } - sparkUI.attachTab(newTab) - eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sparkUI.appUIAddress).mkString - assert(!html.contains("random data that should not be present")) - - // check whether new page exists - assert(html.toLowerCase.contains("foo")) - - // check whether other pages still exist - assert(html.toLowerCase.contains("stages")) - assert(html.toLowerCase.contains("storage")) - assert(html.toLowerCase.contains("environment")) - assert(html.toLowerCase.contains("executors")) - } - - eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sparkUI.appUIAddress.stripSuffix("/") + "/foo").mkString - // check whether new page exists - assert(html.contains("magic")) - } - } - } - test("jetty selects different port under contention") { val server = new ServerSocket(0) val startPort = server.getLocalPort diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 68074ae32a672..c0c28cb60e21d 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.jobs +import java.util.Properties + import org.scalatest.FunSuite import org.scalatest.Matchers @@ -44,11 +46,19 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc SparkListenerStageCompleted(stageInfo) } - private def createJobStartEvent(jobId: Int, stageIds: Seq[Int]) = { + private def createJobStartEvent( + jobId: Int, + stageIds: Seq[Int], + jobGroup: Option[String] = None): SparkListenerJobStart = { val stageInfos = stageIds.map { stageId => new StageInfo(stageId, 0, stageId.toString, 0, null, "") } - SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos) + val properties: Option[Properties] = jobGroup.map { groupId => + val props = new Properties() + props.setProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) + props + } + SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos, properties.orNull) } private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { @@ -88,6 +98,45 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc listener.completedStages.map(_.stageId).toSet should be (Set(50, 49, 48, 47, 46)) } + test("test clearing of stageIdToActiveJobs") { + val conf = new SparkConf() + conf.set("spark.ui.retainedStages", 5.toString) + val listener = new JobProgressListener(conf) + val jobId = 0 + val stageIds = 1 to 50 + // Start a job with 50 stages + listener.onJobStart(createJobStartEvent(jobId, stageIds)) + for (stageId <- stageIds) { + listener.onStageSubmitted(createStageStartEvent(stageId)) + } + listener.stageIdToActiveJobIds.size should be > 0 + + // Complete the stages and job + for (stageId <- stageIds) { + listener.onStageCompleted(createStageEndEvent(stageId, failed = false)) + } + listener.onJobEnd(createJobEndEvent(jobId, false)) + assertActiveJobsStateIsEmpty(listener) + listener.stageIdToActiveJobIds.size should be (0) + } + + test("test clearing of jobGroupToJobIds") { + val conf = new SparkConf() + conf.set("spark.ui.retainedJobs", 5.toString) + val listener = new JobProgressListener(conf) + + // Run 50 jobs, each with one stage + for (jobId <- 0 to 50) { + listener.onJobStart(createJobStartEvent(jobId, Seq(0), jobGroup = Some(jobId.toString))) + listener.onStageSubmitted(createStageStartEvent(0)) + listener.onStageCompleted(createStageEndEvent(0, failed = false)) + listener.onJobEnd(createJobEndEvent(jobId, false)) + } + assertActiveJobsStateIsEmpty(listener) + // This collection won't become empty, but it should be bounded by spark.ui.retainedJobs + listener.jobGroupToJobIds.size should be (5) + } + test("test LRU eviction of jobs") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) @@ -227,6 +276,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) shuffleReadMetrics.incRemoteBytesRead(base + 1) + shuffleReadMetrics.incLocalBytesRead(base + 9) shuffleReadMetrics.incRemoteBlocksFetched(base + 2) shuffleWriteMetrics.incShuffleBytesWritten(base + 3) taskMetrics.setExecutorRunTime(base + 4) @@ -234,7 +284,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.incMemoryBytesSpilled(base + 6) val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) taskMetrics.setInputMetrics(Some(inputMetrics)) - inputMetrics.addBytesRead(base + 7) + inputMetrics.incBytesRead(base + 7) val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) taskMetrics.outputMetrics = Some(outputMetrics) outputMetrics.setBytesWritten(base + 8) @@ -260,8 +310,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc var stage0Data = listener.stageIdToData.get((0, 0)).get var stage1Data = listener.stageIdToData.get((1, 0)).get - assert(stage0Data.shuffleReadBytes == 102) - assert(stage1Data.shuffleReadBytes == 201) + assert(stage0Data.shuffleReadTotalBytes == 220) + assert(stage1Data.shuffleReadTotalBytes == 410) assert(stage0Data.shuffleWriteBytes == 106) assert(stage1Data.shuffleWriteBytes == 203) assert(stage0Data.executorRunTime == 108) @@ -290,8 +340,11 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc stage0Data = listener.stageIdToData.get((0, 0)).get stage1Data = listener.stageIdToData.get((1, 0)).get - assert(stage0Data.shuffleReadBytes == 402) - assert(stage1Data.shuffleReadBytes == 602) + // Task 1235 contributed (100+1)+(100+9) = 210 shuffle bytes, and task 1234 contributed + // (300+1)+(300+9) = 610 total shuffle bytes, so the total for the stage is 820. + assert(stage0Data.shuffleReadTotalBytes == 820) + // Task 1236 contributed 410 shuffle bytes, and task 1237 contributed 810 shuffle bytes. + assert(stage1Data.shuffleReadTotalBytes == 1220) assert(stage0Data.shuffleWriteBytes == 406) assert(stage1Data.shuffleWriteBytes == 606) assert(stage0Data.executorRunTime == 408) diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 39e5d367d676c..6250d50fb7036 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.util import java.util.concurrent.TimeoutException import scala.concurrent.Await +import scala.util.{Failure, Try} import akka.actor._ @@ -370,8 +371,12 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val selection = slaveSystem.actorSelection( AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) val timeout = AkkaUtils.lookupTimeout(conf) - intercept[TimeoutException] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + val result = Try(Await.result(selection.resolveOne(timeout * 2), timeout)) + + result match { + case Failure(ex: ActorNotFound) => + case Failure(ex: TimeoutException) => + case r => fail(s"$r is neither Failure(ActorNotFound) nor Failure(TimeoutException)") } actorSystem.shutdown() diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 1026cb2aa7cae..47b535206c949 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -203,4 +203,76 @@ class EventLoopSuite extends FunSuite with Timeouts { assert(!eventLoop.isActive) } } + + test("EventLoop: stop() in onStart should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onStart(): Unit = { + stop() + } + + override def onReceive(event: Int): Unit = { + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + test("EventLoop: stop() in onReceive should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + test("EventLoop: stop() in onError should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + throw new RuntimeException("Oops") + } + + override def onError(e: Throwable): Unit = { + stop() + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 6577ebaa2e9a8..a2be724254d7c 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -76,8 +76,9 @@ class JsonProtocolSuite extends FunSuite { val unpersistRdd = SparkListenerUnpersistRDD(12345) val applicationStart = SparkListenerApplicationStart("The winner of all", None, 42L, "Garfield") val applicationEnd = SparkListenerApplicationEnd(42L) + val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", - new ExecutorInfo("Hostee.awesome.com", 11)) + new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap)) val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") testEvent(stageSubmitted, stageSubmittedJsonString) @@ -100,13 +101,14 @@ class JsonProtocolSuite extends FunSuite { } test("Dependent Classes") { + val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L)) testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) testTaskMetrics(makeTaskMetrics( 33333L, 44444L, 55555L, 66666L, 7, 8, hasHadoopInput = false, hasOutput = false)) testBlockManagerId(BlockManagerId("Hong", "Kong", 500)) - testExecutorInfo(new ExecutorInfo("host", 43)) + testExecutorInfo(new ExecutorInfo("host", 43, logUrlMap)) // StorageLevel testStorageLevel(StorageLevel.NONE) @@ -187,6 +189,34 @@ class JsonProtocolSuite extends FunSuite { assert(newMetrics.inputMetrics.isEmpty) } + test("Input/Output records backwards compatibility") { + // records read were added after 1.2 + val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, + hasHadoopInput = true, hasOutput = true, hasRecords = false) + assert(metrics.inputMetrics.nonEmpty) + assert(metrics.outputMetrics.nonEmpty) + val newJson = JsonProtocol.taskMetricsToJson(metrics) + val oldJson = newJson.removeField { case (field, _) => field == "Records Read" } + .removeField { case (field, _) => field == "Records Written" } + val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) + assert(newMetrics.inputMetrics.get.recordsRead == 0) + assert(newMetrics.outputMetrics.get.recordsWritten == 0) + } + + test("Shuffle Read/Write records backwards compatibility") { + // records read were added after 1.2 + val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, + hasHadoopInput = false, hasOutput = false, hasRecords = false) + assert(metrics.shuffleReadMetrics.nonEmpty) + assert(metrics.shuffleWriteMetrics.nonEmpty) + val newJson = JsonProtocol.taskMetricsToJson(metrics) + val oldJson = newJson.removeField { case (field, _) => field == "Total Records Read" } + .removeField { case (field, _) => field == "Shuffle Records Written" } + val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) + assert(newMetrics.shuffleReadMetrics.get.recordsRead == 0) + assert(newMetrics.shuffleWriteMetrics.get.shuffleRecordsWritten == 0) + } + test("OutputMetrics backward compatibility") { // OutputMetrics were added after 1.1 val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = false, hasOutput = true) @@ -230,6 +260,18 @@ class JsonProtocolSuite extends FunSuite { assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) } + test("ShuffleReadMetrics: Local bytes read and time taken backwards compatibility") { + // Metrics about local shuffle bytes read and local read time were added in 1.3.1. + val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, + hasHadoopInput = false, hasOutput = false, hasRecords = false) + assert(metrics.shuffleReadMetrics.nonEmpty) + val newJson = JsonProtocol.taskMetricsToJson(metrics) + val oldJson = newJson.removeField { case (field, _) => field == "Local Bytes Read" } + .removeField { case (field, _) => field == "Local Read Time" } + val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) + assert(newMetrics.shuffleReadMetrics.get.localBytesRead == 0) + } + test("SparkListenerApplicationStart backwards compatibility") { // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user") @@ -642,7 +684,8 @@ class JsonProtocolSuite extends FunSuite { e: Int, f: Int, hasHadoopInput: Boolean, - hasOutput: Boolean) = { + hasOutput: Boolean, + hasRecords: Boolean = true) = { val t = new TaskMetrics t.setHostname("localhost") t.setExecutorDeserializeTime(a) @@ -654,7 +697,8 @@ class JsonProtocolSuite extends FunSuite { if (hasHadoopInput) { val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - inputMetrics.addBytesRead(d + e + f) + inputMetrics.incBytesRead(d + e + f) + inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1) t.setInputMetrics(Some(inputMetrics)) } else { val sr = new ShuffleReadMetrics @@ -662,16 +706,20 @@ class JsonProtocolSuite extends FunSuite { sr.incLocalBlocksFetched(e) sr.incFetchWaitTime(a + d) sr.incRemoteBlocksFetched(f) + sr.incRecordsRead(if (hasRecords) (b + d) / 100 else -1) + sr.incLocalBytesRead(a + f) t.setShuffleReadMetrics(Some(sr)) } if (hasOutput) { val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) outputMetrics.setBytesWritten(a + b + c) + outputMetrics.setRecordsWritten(if (hasRecords) (a + b + c)/100 else -1) t.outputMetrics = Some(outputMetrics) } else { val sw = new ShuffleWriteMetrics sw.incShuffleBytesWritten(a + b + c) sw.incShuffleWriteTime(b + c + d) + sw.setShuffleRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) t.shuffleWriteMetrics = Some(sw) } // Make at most 6 blocks @@ -905,11 +953,14 @@ class JsonProtocolSuite extends FunSuite { | "Remote Blocks Fetched": 800, | "Local Blocks Fetched": 700, | "Fetch Wait Time": 900, - | "Remote Bytes Read": 1000 + | "Remote Bytes Read": 1000, + | "Local Bytes Read": 1100, + | "Total Records Read" : 10 | }, | "Shuffle Write Metrics": { | "Shuffle Bytes Written": 1200, - | "Shuffle Write Time": 1500 + | "Shuffle Write Time": 1500, + | "Shuffle Records Written": 12 | }, | "Updated Blocks": [ | { @@ -986,11 +1037,13 @@ class JsonProtocolSuite extends FunSuite { | "Disk Bytes Spilled": 0, | "Shuffle Write Metrics": { | "Shuffle Bytes Written": 1200, - | "Shuffle Write Time": 1500 + | "Shuffle Write Time": 1500, + | "Shuffle Records Written": 12 | }, | "Input Metrics": { | "Data Read Method": "Hadoop", - | "Bytes Read": 2100 + | "Bytes Read": 2100, + | "Records Read": 21 | }, | "Updated Blocks": [ | { @@ -1067,11 +1120,13 @@ class JsonProtocolSuite extends FunSuite { | "Disk Bytes Spilled": 0, | "Input Metrics": { | "Data Read Method": "Hadoop", - | "Bytes Read": 2100 + | "Bytes Read": 2100, + | "Records Read": 21 | }, | "Output Metrics": { | "Data Write Method": "Hadoop", - | "Bytes Written": 1200 + | "Bytes Written": 1200, + | "Records Written": 12 | }, | "Updated Blocks": [ | { @@ -1463,7 +1518,11 @@ class JsonProtocolSuite extends FunSuite { | "Executor ID": "exec1", | "Executor Info": { | "Host": "Hostee.awesome.com", - | "Total Cores": 11 + | "Total Cores": 11, + | "Log Urls" : { + | "stderr" : "mystderr", + | "stdout" : "mystdout" + | } | } |} """ diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala similarity index 75% rename from core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala rename to core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index e2050e95a1b88..31e3b7e7bb71b 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.executor +package org.apache.spark.util import java.net.URLClassLoader @@ -24,32 +24,40 @@ import org.scalatest.FunSuite import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, TestUtils} import org.apache.spark.util.Utils -class ExecutorURLClassLoaderSuite extends FunSuite { +class MutableURLClassLoaderSuite extends FunSuite { - val childClassNames = List("FakeClass1", "FakeClass2") - val parentClassNames = List("FakeClass1", "FakeClass2", "FakeClass3") - val urls = List(TestUtils.createJarWithClasses(childClassNames, "1")).toArray - val urls2 = List(TestUtils.createJarWithClasses(parentClassNames, "2")).toArray + val urls2 = List(TestUtils.createJarWithClasses( + classNames = Seq("FakeClass1", "FakeClass2", "FakeClass3"), + toStringValue = "2")).toArray + val urls = List(TestUtils.createJarWithClasses( + classNames = Seq("FakeClass1"), + classNamesWithBase = Seq(("FakeClass2", "FakeClass3")), // FakeClass3 is in parent + toStringValue = "1", + classpathUrls = urls2)).toArray test("child first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) val fakeClass = classLoader.loadClass("FakeClass2").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") + val fakeClass2 = classLoader.loadClass("FakeClass2").newInstance() + assert(fakeClass.getClass === fakeClass2.getClass) } test("parent first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorURLClassLoader(urls, parentLoader) + val classLoader = new MutableURLClassLoader(urls, parentLoader) val fakeClass = classLoader.loadClass("FakeClass1").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") + val fakeClass2 = classLoader.loadClass("FakeClass1").newInstance() + assert(fakeClass.getClass === fakeClass2.getClass) } test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) val fakeClass = classLoader.loadClass("FakeClass3").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -57,7 +65,7 @@ class ExecutorURLClassLoaderSuite extends FunSuite { test("child first can fail") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) intercept[java.lang.ClassNotFoundException] { classLoader.loadClass("FakeClassDoesNotExist").newInstance() } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4544382094f96..fd77753c0d362 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -29,6 +29,9 @@ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.scalatest.FunSuite +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { @@ -205,18 +208,18 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { child1.setLastModified(System.currentTimeMillis() - (1000 * 30)) // although child1 is old, child2 is still new so return true - assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) child2.setLastModified(System.currentTimeMillis - (1000 * 30)) - assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) parent.setLastModified(System.currentTimeMillis - (1000 * 30)) // although parent and its immediate children are new, child3 is still old // we expect a full recursive search for new files. - assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) child3.setLastModified(System.currentTimeMillis - (1000 * 30)) - assert(!Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + assert(!Utils.doesDirectoryContainAnyNewFiles(parent, 5)) } test("resolveURI") { @@ -336,21 +339,21 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(!tempDir1.exists()) val tempDir2 = Utils.createTempDir() - val tempFile1 = new File(tempDir2, "foo.txt") - Files.touch(tempFile1) - assert(tempFile1.exists()) - Utils.deleteRecursively(tempFile1) - assert(!tempFile1.exists()) + val sourceFile1 = new File(tempDir2, "foo.txt") + Files.touch(sourceFile1) + assert(sourceFile1.exists()) + Utils.deleteRecursively(sourceFile1) + assert(!sourceFile1.exists()) val tempDir3 = new File(tempDir2, "subdir") assert(tempDir3.mkdir()) - val tempFile2 = new File(tempDir3, "bar.txt") - Files.touch(tempFile2) - assert(tempFile2.exists()) + val sourceFile2 = new File(tempDir3, "bar.txt") + Files.touch(sourceFile2) + assert(sourceFile2.exists()) Utils.deleteRecursively(tempDir2) assert(!tempDir2.exists()) assert(!tempDir3.exists()) - assert(!tempFile2.exists()) + assert(!sourceFile2.exists()) } test("loading properties from file") { @@ -381,4 +384,41 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { require(cnt === 2, "prepare should be called twice") require(time < 500, "preparation time should not count") } + + test("fetch hcfs dir") { + val sourceDir = Utils.createTempDir() + val innerSourceDir = Utils.createTempDir(root=sourceDir.getPath) + val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) + val targetDir = new File(Utils.createTempDir(), "target-dir") + Files.write("some text", sourceFile, UTF_8) + + val path = new Path("file://" + sourceDir.getAbsolutePath) + val conf = new Configuration() + val fs = Utils.getHadoopFileSystem(path.toString, conf) + + assert(!targetDir.isDirectory()) + Utils.fetchHcfsFile(path, targetDir, fs, new SparkConf(), conf, false) + assert(targetDir.isDirectory()) + + // Copy again to make sure it doesn't error if the dir already exists. + Utils.fetchHcfsFile(path, targetDir, fs, new SparkConf(), conf, false) + + val destDir = new File(targetDir, sourceDir.getName()) + assert(destDir.isDirectory()) + + val destInnerDir = new File(destDir, innerSourceDir.getName) + assert(destInnerDir.isDirectory()) + + val destInnerFile = new File(destInnerDir, sourceFile.getName) + assert(destInnerFile.isFile()) + + val filePath = new Path("file://" + sourceFile.getAbsolutePath) + val testFileDir = new File("test-filename") + val testFileName = "testFName" + val testFilefs = Utils.getHadoopFileSystem(filePath.toString, conf) + Utils.fetchHcfsFile(filePath, testFileDir, testFilefs, new SparkConf(), + conf, false, Some(testFileName)) + val newFileName = new File(testFileDir, testFileName) + assert(newFileName.isFile()) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index 0cb1ed7397655..e0d6cc16bde05 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -65,6 +65,13 @@ class SorterSuite extends FunSuite { } } + // http://www.envisage-project.eu/timsort-specification-and-verification/ + test("SPARK-5984 TimSort bug") { + val data = TestTimSort.getTimSortBugTestSet(67108864) + new Sorter(new IntArraySortDataFormat).sort(data, 0, data.length, Ordering.Int) + (0 to data.length - 2).foreach(i => assert(data(i) <= data(i + 1))) + } + /** Runs an experiment several times. */ def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = { if (skip) { diff --git a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala index 4918e2d92beb4..daa795a043495 100644 --- a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala +++ b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala @@ -44,13 +44,21 @@ class ImplicitSuite { } def testRddToSequenceFileRDDFunctions(): Unit = { - // TODO eliminating `import intToIntWritable` needs refactoring SequenceFileRDDFunctions. - // That will be a breaking change. - import org.apache.spark.SparkContext.intToIntWritable val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD rdd.saveAsSequenceFile("/a/test/path") } + def testRddToSequenceFileRDDFunctionsWithWritable(): Unit = { + val rdd: org.apache.spark.rdd.RDD[(org.apache.hadoop.io.IntWritable, org.apache.hadoop.io.Text)] + = mockRDD + rdd.saveAsSequenceFile("/a/test/path") + } + + def testRddToSequenceFileRDDFunctionsWithBytesArray(): Unit = { + val rdd: org.apache.spark.rdd.RDD[(Int, Array[Byte])] = mockRDD + rdd.saveAsSequenceFile("/a/test/path") + } + def testRddToOrderedRDDFunctions(): Unit = { val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD rdd.sortByKey() diff --git a/data/mllib/als/sample_movielens_movies.txt b/data/mllib/als/sample_movielens_movies.txt new file mode 100644 index 0000000000000..934a0253849e1 --- /dev/null +++ b/data/mllib/als/sample_movielens_movies.txt @@ -0,0 +1,100 @@ +0::Movie 0::Romance|Comedy +1::Movie 1::Action|Anime +2::Movie 2::Romance|Thriller +3::Movie 3::Action|Romance +4::Movie 4::Anime|Comedy +5::Movie 5::Action|Action +6::Movie 6::Action|Comedy +7::Movie 7::Anime|Comedy +8::Movie 8::Comedy|Action +9::Movie 9::Anime|Thriller +10::Movie 10::Action|Anime +11::Movie 11::Action|Anime +12::Movie 12::Anime|Comedy +13::Movie 13::Thriller|Action +14::Movie 14::Anime|Comedy +15::Movie 15::Comedy|Thriller +16::Movie 16::Anime|Romance +17::Movie 17::Thriller|Action +18::Movie 18::Action|Comedy +19::Movie 19::Anime|Romance +20::Movie 20::Action|Anime +21::Movie 21::Romance|Thriller +22::Movie 22::Romance|Romance +23::Movie 23::Comedy|Comedy +24::Movie 24::Anime|Action +25::Movie 25::Comedy|Comedy +26::Movie 26::Anime|Romance +27::Movie 27::Anime|Anime +28::Movie 28::Thriller|Anime +29::Movie 29::Anime|Romance +30::Movie 30::Thriller|Romance +31::Movie 31::Thriller|Romance +32::Movie 32::Comedy|Anime +33::Movie 33::Comedy|Comedy +34::Movie 34::Anime|Anime +35::Movie 35::Action|Thriller +36::Movie 36::Anime|Romance +37::Movie 37::Romance|Anime +38::Movie 38::Thriller|Romance +39::Movie 39::Romance|Comedy +40::Movie 40::Action|Anime +41::Movie 41::Comedy|Thriller +42::Movie 42::Comedy|Action +43::Movie 43::Thriller|Anime +44::Movie 44::Anime|Action +45::Movie 45::Comedy|Romance +46::Movie 46::Comedy|Action +47::Movie 47::Romance|Comedy +48::Movie 48::Action|Comedy +49::Movie 49::Romance|Romance +50::Movie 50::Comedy|Romance +51::Movie 51::Action|Action +52::Movie 52::Thriller|Action +53::Movie 53::Action|Action +54::Movie 54::Romance|Thriller +55::Movie 55::Anime|Romance +56::Movie 56::Comedy|Action +57::Movie 57::Action|Anime +58::Movie 58::Thriller|Romance +59::Movie 59::Thriller|Comedy +60::Movie 60::Anime|Comedy +61::Movie 61::Comedy|Action +62::Movie 62::Comedy|Romance +63::Movie 63::Romance|Thriller +64::Movie 64::Romance|Action +65::Movie 65::Anime|Romance +66::Movie 66::Comedy|Action +67::Movie 67::Thriller|Anime +68::Movie 68::Thriller|Romance +69::Movie 69::Action|Comedy +70::Movie 70::Thriller|Thriller +71::Movie 71::Action|Comedy +72::Movie 72::Thriller|Romance +73::Movie 73::Comedy|Action +74::Movie 74::Action|Action +75::Movie 75::Action|Action +76::Movie 76::Comedy|Comedy +77::Movie 77::Comedy|Comedy +78::Movie 78::Comedy|Comedy +79::Movie 79::Thriller|Thriller +80::Movie 80::Comedy|Anime +81::Movie 81::Comedy|Anime +82::Movie 82::Romance|Anime +83::Movie 83::Comedy|Thriller +84::Movie 84::Anime|Action +85::Movie 85::Thriller|Anime +86::Movie 86::Romance|Anime +87::Movie 87::Thriller|Thriller +88::Movie 88::Romance|Thriller +89::Movie 89::Action|Anime +90::Movie 90::Anime|Romance +91::Movie 91::Anime|Thriller +92::Movie 92::Action|Comedy +93::Movie 93::Romance|Thriller +94::Movie 94::Thriller|Comedy +95::Movie 95::Action|Action +96::Movie 96::Thriller|Romance +97::Movie 97::Thriller|Thriller +98::Movie 98::Thriller|Comedy +99::Movie 99::Thriller|Romance diff --git a/data/mllib/als/sample_movielens_ratings.txt b/data/mllib/als/sample_movielens_ratings.txt new file mode 100644 index 0000000000000..0889142950797 --- /dev/null +++ b/data/mllib/als/sample_movielens_ratings.txt @@ -0,0 +1,1501 @@ +0::2::3::1424380312 +0::3::1::1424380312 +0::5::2::1424380312 +0::9::4::1424380312 +0::11::1::1424380312 +0::12::2::1424380312 +0::15::1::1424380312 +0::17::1::1424380312 +0::19::1::1424380312 +0::21::1::1424380312 +0::23::1::1424380312 +0::26::3::1424380312 +0::27::1::1424380312 +0::28::1::1424380312 +0::29::1::1424380312 +0::30::1::1424380312 +0::31::1::1424380312 +0::34::1::1424380312 +0::37::1::1424380312 +0::41::2::1424380312 +0::44::1::1424380312 +0::45::2::1424380312 +0::46::1::1424380312 +0::47::1::1424380312 +0::48::1::1424380312 +0::50::1::1424380312 +0::51::1::1424380312 +0::54::1::1424380312 +0::55::1::1424380312 +0::59::2::1424380312 +0::61::2::1424380312 +0::64::1::1424380312 +0::67::1::1424380312 +0::68::1::1424380312 +0::69::1::1424380312 +0::71::1::1424380312 +0::72::1::1424380312 +0::77::2::1424380312 +0::79::1::1424380312 +0::83::1::1424380312 +0::87::1::1424380312 +0::89::2::1424380312 +0::91::3::1424380312 +0::92::4::1424380312 +0::94::1::1424380312 +0::95::2::1424380312 +0::96::1::1424380312 +0::98::1::1424380312 +0::99::1::1424380312 +1::2::2::1424380312 +1::3::1::1424380312 +1::4::2::1424380312 +1::6::1::1424380312 +1::9::3::1424380312 +1::12::1::1424380312 +1::13::1::1424380312 +1::14::1::1424380312 +1::16::1::1424380312 +1::19::1::1424380312 +1::21::3::1424380312 +1::27::1::1424380312 +1::28::3::1424380312 +1::33::1::1424380312 +1::36::2::1424380312 +1::37::1::1424380312 +1::40::1::1424380312 +1::41::2::1424380312 +1::43::1::1424380312 +1::44::1::1424380312 +1::47::1::1424380312 +1::50::1::1424380312 +1::54::1::1424380312 +1::56::2::1424380312 +1::57::1::1424380312 +1::58::1::1424380312 +1::60::1::1424380312 +1::62::4::1424380312 +1::63::1::1424380312 +1::67::1::1424380312 +1::68::4::1424380312 +1::70::2::1424380312 +1::72::1::1424380312 +1::73::1::1424380312 +1::74::2::1424380312 +1::76::1::1424380312 +1::77::3::1424380312 +1::78::1::1424380312 +1::81::1::1424380312 +1::82::1::1424380312 +1::85::3::1424380312 +1::86::2::1424380312 +1::88::2::1424380312 +1::91::1::1424380312 +1::92::2::1424380312 +1::93::1::1424380312 +1::94::2::1424380312 +1::96::1::1424380312 +1::97::1::1424380312 +2::4::3::1424380312 +2::6::1::1424380312 +2::8::5::1424380312 +2::9::1::1424380312 +2::10::1::1424380312 +2::12::3::1424380312 +2::13::1::1424380312 +2::15::2::1424380312 +2::18::2::1424380312 +2::19::4::1424380312 +2::22::1::1424380312 +2::26::1::1424380312 +2::28::1::1424380312 +2::34::4::1424380312 +2::35::1::1424380312 +2::37::5::1424380312 +2::38::1::1424380312 +2::39::5::1424380312 +2::40::4::1424380312 +2::47::1::1424380312 +2::50::1::1424380312 +2::52::2::1424380312 +2::54::1::1424380312 +2::55::1::1424380312 +2::57::2::1424380312 +2::58::2::1424380312 +2::59::1::1424380312 +2::61::1::1424380312 +2::62::1::1424380312 +2::64::1::1424380312 +2::65::1::1424380312 +2::66::3::1424380312 +2::68::1::1424380312 +2::71::3::1424380312 +2::76::1::1424380312 +2::77::1::1424380312 +2::78::1::1424380312 +2::80::1::1424380312 +2::83::5::1424380312 +2::85::1::1424380312 +2::87::2::1424380312 +2::88::1::1424380312 +2::89::4::1424380312 +2::90::1::1424380312 +2::92::4::1424380312 +2::93::5::1424380312 +3::0::1::1424380312 +3::1::1::1424380312 +3::2::1::1424380312 +3::7::3::1424380312 +3::8::3::1424380312 +3::9::1::1424380312 +3::14::1::1424380312 +3::15::1::1424380312 +3::16::1::1424380312 +3::18::4::1424380312 +3::19::1::1424380312 +3::24::3::1424380312 +3::26::1::1424380312 +3::29::3::1424380312 +3::33::1::1424380312 +3::34::3::1424380312 +3::35::1::1424380312 +3::36::3::1424380312 +3::37::1::1424380312 +3::38::2::1424380312 +3::43::1::1424380312 +3::44::1::1424380312 +3::46::1::1424380312 +3::47::1::1424380312 +3::51::5::1424380312 +3::52::3::1424380312 +3::56::1::1424380312 +3::58::1::1424380312 +3::60::3::1424380312 +3::62::1::1424380312 +3::65::2::1424380312 +3::66::1::1424380312 +3::67::1::1424380312 +3::68::2::1424380312 +3::70::1::1424380312 +3::72::2::1424380312 +3::76::3::1424380312 +3::79::3::1424380312 +3::80::4::1424380312 +3::81::1::1424380312 +3::83::1::1424380312 +3::84::1::1424380312 +3::86::1::1424380312 +3::87::2::1424380312 +3::88::4::1424380312 +3::89::1::1424380312 +3::91::1::1424380312 +3::94::3::1424380312 +4::1::1::1424380312 +4::6::1::1424380312 +4::8::1::1424380312 +4::9::1::1424380312 +4::10::1::1424380312 +4::11::1::1424380312 +4::12::1::1424380312 +4::13::1::1424380312 +4::14::2::1424380312 +4::15::1::1424380312 +4::17::1::1424380312 +4::20::1::1424380312 +4::22::1::1424380312 +4::23::1::1424380312 +4::24::1::1424380312 +4::29::4::1424380312 +4::30::1::1424380312 +4::31::1::1424380312 +4::34::1::1424380312 +4::35::1::1424380312 +4::36::1::1424380312 +4::39::2::1424380312 +4::40::3::1424380312 +4::41::4::1424380312 +4::43::2::1424380312 +4::44::1::1424380312 +4::45::1::1424380312 +4::46::1::1424380312 +4::47::1::1424380312 +4::49::2::1424380312 +4::50::1::1424380312 +4::51::1::1424380312 +4::52::4::1424380312 +4::54::1::1424380312 +4::55::1::1424380312 +4::60::3::1424380312 +4::61::1::1424380312 +4::62::4::1424380312 +4::63::3::1424380312 +4::65::1::1424380312 +4::67::2::1424380312 +4::69::1::1424380312 +4::70::4::1424380312 +4::71::1::1424380312 +4::73::1::1424380312 +4::78::1::1424380312 +4::84::1::1424380312 +4::85::1::1424380312 +4::87::3::1424380312 +4::88::3::1424380312 +4::89::2::1424380312 +4::96::1::1424380312 +4::97::1::1424380312 +4::98::1::1424380312 +4::99::1::1424380312 +5::0::1::1424380312 +5::1::1::1424380312 +5::4::1::1424380312 +5::5::1::1424380312 +5::8::1::1424380312 +5::9::3::1424380312 +5::10::2::1424380312 +5::13::3::1424380312 +5::15::1::1424380312 +5::19::1::1424380312 +5::20::3::1424380312 +5::21::2::1424380312 +5::23::3::1424380312 +5::27::1::1424380312 +5::28::1::1424380312 +5::29::1::1424380312 +5::31::1::1424380312 +5::36::3::1424380312 +5::38::2::1424380312 +5::39::1::1424380312 +5::42::1::1424380312 +5::48::3::1424380312 +5::49::4::1424380312 +5::50::3::1424380312 +5::51::1::1424380312 +5::52::1::1424380312 +5::54::1::1424380312 +5::55::5::1424380312 +5::56::3::1424380312 +5::58::1::1424380312 +5::60::1::1424380312 +5::61::1::1424380312 +5::64::3::1424380312 +5::65::2::1424380312 +5::68::4::1424380312 +5::70::1::1424380312 +5::71::1::1424380312 +5::72::1::1424380312 +5::74::1::1424380312 +5::79::1::1424380312 +5::81::2::1424380312 +5::84::1::1424380312 +5::85::1::1424380312 +5::86::1::1424380312 +5::88::1::1424380312 +5::90::4::1424380312 +5::91::2::1424380312 +5::95::2::1424380312 +5::99::1::1424380312 +6::0::1::1424380312 +6::1::1::1424380312 +6::2::3::1424380312 +6::5::1::1424380312 +6::6::1::1424380312 +6::9::1::1424380312 +6::10::1::1424380312 +6::15::2::1424380312 +6::16::2::1424380312 +6::17::1::1424380312 +6::18::1::1424380312 +6::20::1::1424380312 +6::21::1::1424380312 +6::22::1::1424380312 +6::24::1::1424380312 +6::25::5::1424380312 +6::26::1::1424380312 +6::28::1::1424380312 +6::30::1::1424380312 +6::33::1::1424380312 +6::38::1::1424380312 +6::39::1::1424380312 +6::43::4::1424380312 +6::44::1::1424380312 +6::45::1::1424380312 +6::48::1::1424380312 +6::49::1::1424380312 +6::50::1::1424380312 +6::53::1::1424380312 +6::54::1::1424380312 +6::55::1::1424380312 +6::56::1::1424380312 +6::58::4::1424380312 +6::59::1::1424380312 +6::60::1::1424380312 +6::61::3::1424380312 +6::63::3::1424380312 +6::66::1::1424380312 +6::67::3::1424380312 +6::68::1::1424380312 +6::69::1::1424380312 +6::71::2::1424380312 +6::73::1::1424380312 +6::75::1::1424380312 +6::77::1::1424380312 +6::79::1::1424380312 +6::81::1::1424380312 +6::84::1::1424380312 +6::85::3::1424380312 +6::86::1::1424380312 +6::87::1::1424380312 +6::88::1::1424380312 +6::89::1::1424380312 +6::91::2::1424380312 +6::94::1::1424380312 +6::95::2::1424380312 +6::96::1::1424380312 +7::1::1::1424380312 +7::2::2::1424380312 +7::3::1::1424380312 +7::4::1::1424380312 +7::7::1::1424380312 +7::10::1::1424380312 +7::11::2::1424380312 +7::14::2::1424380312 +7::15::1::1424380312 +7::16::1::1424380312 +7::18::1::1424380312 +7::21::1::1424380312 +7::22::1::1424380312 +7::23::1::1424380312 +7::25::5::1424380312 +7::26::1::1424380312 +7::29::4::1424380312 +7::30::1::1424380312 +7::31::3::1424380312 +7::32::1::1424380312 +7::33::1::1424380312 +7::35::1::1424380312 +7::37::2::1424380312 +7::39::3::1424380312 +7::40::2::1424380312 +7::42::2::1424380312 +7::44::1::1424380312 +7::45::2::1424380312 +7::47::4::1424380312 +7::48::1::1424380312 +7::49::1::1424380312 +7::53::1::1424380312 +7::54::1::1424380312 +7::55::1::1424380312 +7::56::1::1424380312 +7::59::1::1424380312 +7::61::2::1424380312 +7::62::3::1424380312 +7::63::2::1424380312 +7::66::1::1424380312 +7::67::3::1424380312 +7::74::1::1424380312 +7::75::1::1424380312 +7::76::3::1424380312 +7::77::1::1424380312 +7::81::1::1424380312 +7::82::1::1424380312 +7::84::2::1424380312 +7::85::4::1424380312 +7::86::1::1424380312 +7::92::2::1424380312 +7::96::1::1424380312 +7::97::1::1424380312 +7::98::1::1424380312 +8::0::1::1424380312 +8::2::4::1424380312 +8::3::2::1424380312 +8::4::2::1424380312 +8::5::1::1424380312 +8::7::1::1424380312 +8::9::1::1424380312 +8::11::1::1424380312 +8::15::1::1424380312 +8::18::1::1424380312 +8::19::1::1424380312 +8::21::1::1424380312 +8::29::5::1424380312 +8::31::3::1424380312 +8::33::1::1424380312 +8::35::1::1424380312 +8::36::1::1424380312 +8::40::2::1424380312 +8::44::1::1424380312 +8::45::1::1424380312 +8::50::1::1424380312 +8::51::1::1424380312 +8::52::5::1424380312 +8::53::5::1424380312 +8::54::1::1424380312 +8::55::1::1424380312 +8::56::1::1424380312 +8::58::4::1424380312 +8::60::3::1424380312 +8::62::4::1424380312 +8::64::1::1424380312 +8::67::3::1424380312 +8::69::1::1424380312 +8::71::1::1424380312 +8::72::3::1424380312 +8::77::3::1424380312 +8::78::1::1424380312 +8::79::1::1424380312 +8::83::1::1424380312 +8::85::5::1424380312 +8::86::1::1424380312 +8::88::1::1424380312 +8::90::1::1424380312 +8::92::2::1424380312 +8::95::4::1424380312 +8::96::3::1424380312 +8::97::1::1424380312 +8::98::1::1424380312 +8::99::1::1424380312 +9::2::3::1424380312 +9::3::1::1424380312 +9::4::1::1424380312 +9::5::1::1424380312 +9::6::1::1424380312 +9::7::5::1424380312 +9::9::1::1424380312 +9::12::1::1424380312 +9::14::3::1424380312 +9::15::1::1424380312 +9::19::1::1424380312 +9::21::1::1424380312 +9::22::1::1424380312 +9::24::1::1424380312 +9::25::1::1424380312 +9::26::1::1424380312 +9::30::3::1424380312 +9::32::4::1424380312 +9::35::2::1424380312 +9::36::2::1424380312 +9::37::2::1424380312 +9::38::1::1424380312 +9::39::1::1424380312 +9::43::3::1424380312 +9::49::5::1424380312 +9::50::3::1424380312 +9::53::1::1424380312 +9::54::1::1424380312 +9::58::1::1424380312 +9::59::1::1424380312 +9::60::1::1424380312 +9::61::1::1424380312 +9::63::3::1424380312 +9::64::3::1424380312 +9::68::1::1424380312 +9::69::1::1424380312 +9::70::3::1424380312 +9::71::1::1424380312 +9::73::2::1424380312 +9::75::1::1424380312 +9::77::2::1424380312 +9::81::2::1424380312 +9::82::1::1424380312 +9::83::1::1424380312 +9::84::1::1424380312 +9::86::1::1424380312 +9::87::4::1424380312 +9::88::1::1424380312 +9::90::3::1424380312 +9::94::2::1424380312 +9::95::3::1424380312 +9::97::2::1424380312 +9::98::1::1424380312 +10::0::3::1424380312 +10::2::4::1424380312 +10::4::3::1424380312 +10::7::1::1424380312 +10::8::1::1424380312 +10::10::1::1424380312 +10::13::2::1424380312 +10::14::1::1424380312 +10::16::2::1424380312 +10::17::1::1424380312 +10::18::1::1424380312 +10::21::1::1424380312 +10::22::1::1424380312 +10::24::1::1424380312 +10::25::3::1424380312 +10::28::1::1424380312 +10::35::1::1424380312 +10::36::1::1424380312 +10::37::1::1424380312 +10::38::1::1424380312 +10::39::1::1424380312 +10::40::4::1424380312 +10::41::2::1424380312 +10::42::3::1424380312 +10::43::1::1424380312 +10::49::3::1424380312 +10::50::1::1424380312 +10::51::1::1424380312 +10::52::1::1424380312 +10::55::2::1424380312 +10::56::1::1424380312 +10::58::1::1424380312 +10::63::1::1424380312 +10::66::1::1424380312 +10::67::2::1424380312 +10::68::1::1424380312 +10::75::1::1424380312 +10::77::1::1424380312 +10::79::1::1424380312 +10::86::1::1424380312 +10::89::3::1424380312 +10::90::1::1424380312 +10::97::1::1424380312 +10::98::1::1424380312 +11::0::1::1424380312 +11::6::2::1424380312 +11::9::1::1424380312 +11::10::1::1424380312 +11::11::1::1424380312 +11::12::1::1424380312 +11::13::4::1424380312 +11::16::1::1424380312 +11::18::5::1424380312 +11::19::4::1424380312 +11::20::1::1424380312 +11::21::1::1424380312 +11::22::1::1424380312 +11::23::5::1424380312 +11::25::1::1424380312 +11::27::5::1424380312 +11::30::5::1424380312 +11::32::5::1424380312 +11::35::3::1424380312 +11::36::2::1424380312 +11::37::2::1424380312 +11::38::4::1424380312 +11::39::1::1424380312 +11::40::1::1424380312 +11::41::1::1424380312 +11::43::2::1424380312 +11::45::1::1424380312 +11::47::1::1424380312 +11::48::5::1424380312 +11::50::4::1424380312 +11::51::3::1424380312 +11::59::1::1424380312 +11::61::1::1424380312 +11::62::1::1424380312 +11::64::1::1424380312 +11::66::4::1424380312 +11::67::1::1424380312 +11::69::5::1424380312 +11::70::1::1424380312 +11::71::3::1424380312 +11::72::3::1424380312 +11::75::3::1424380312 +11::76::1::1424380312 +11::77::1::1424380312 +11::78::1::1424380312 +11::79::5::1424380312 +11::80::3::1424380312 +11::81::4::1424380312 +11::82::1::1424380312 +11::86::1::1424380312 +11::88::1::1424380312 +11::89::1::1424380312 +11::90::4::1424380312 +11::94::2::1424380312 +11::97::3::1424380312 +11::99::1::1424380312 +12::2::1::1424380312 +12::4::1::1424380312 +12::6::1::1424380312 +12::7::3::1424380312 +12::8::1::1424380312 +12::14::1::1424380312 +12::15::2::1424380312 +12::16::4::1424380312 +12::17::5::1424380312 +12::18::2::1424380312 +12::21::1::1424380312 +12::22::2::1424380312 +12::23::3::1424380312 +12::24::1::1424380312 +12::25::1::1424380312 +12::27::5::1424380312 +12::30::2::1424380312 +12::31::4::1424380312 +12::35::5::1424380312 +12::38::1::1424380312 +12::41::1::1424380312 +12::44::2::1424380312 +12::45::1::1424380312 +12::50::4::1424380312 +12::51::1::1424380312 +12::52::1::1424380312 +12::53::1::1424380312 +12::54::1::1424380312 +12::56::2::1424380312 +12::57::1::1424380312 +12::60::1::1424380312 +12::63::1::1424380312 +12::64::5::1424380312 +12::66::3::1424380312 +12::67::1::1424380312 +12::70::1::1424380312 +12::72::1::1424380312 +12::74::1::1424380312 +12::75::1::1424380312 +12::77::1::1424380312 +12::78::1::1424380312 +12::79::3::1424380312 +12::82::2::1424380312 +12::83::1::1424380312 +12::84::1::1424380312 +12::85::1::1424380312 +12::86::1::1424380312 +12::87::1::1424380312 +12::88::1::1424380312 +12::91::3::1424380312 +12::92::1::1424380312 +12::94::4::1424380312 +12::95::2::1424380312 +12::96::1::1424380312 +12::98::2::1424380312 +13::0::1::1424380312 +13::3::1::1424380312 +13::4::2::1424380312 +13::5::1::1424380312 +13::6::1::1424380312 +13::12::1::1424380312 +13::14::2::1424380312 +13::15::1::1424380312 +13::17::1::1424380312 +13::18::3::1424380312 +13::20::1::1424380312 +13::21::1::1424380312 +13::22::1::1424380312 +13::26::1::1424380312 +13::27::1::1424380312 +13::29::3::1424380312 +13::31::1::1424380312 +13::33::1::1424380312 +13::40::2::1424380312 +13::43::2::1424380312 +13::44::1::1424380312 +13::45::1::1424380312 +13::49::1::1424380312 +13::51::1::1424380312 +13::52::2::1424380312 +13::53::3::1424380312 +13::54::1::1424380312 +13::62::1::1424380312 +13::63::2::1424380312 +13::64::1::1424380312 +13::68::1::1424380312 +13::71::1::1424380312 +13::72::3::1424380312 +13::73::1::1424380312 +13::74::3::1424380312 +13::77::2::1424380312 +13::78::1::1424380312 +13::79::2::1424380312 +13::83::3::1424380312 +13::85::1::1424380312 +13::86::1::1424380312 +13::87::2::1424380312 +13::88::2::1424380312 +13::90::1::1424380312 +13::93::4::1424380312 +13::94::1::1424380312 +13::98::1::1424380312 +13::99::1::1424380312 +14::1::1::1424380312 +14::3::3::1424380312 +14::4::1::1424380312 +14::5::1::1424380312 +14::6::1::1424380312 +14::7::1::1424380312 +14::9::1::1424380312 +14::10::1::1424380312 +14::11::1::1424380312 +14::12::1::1424380312 +14::13::1::1424380312 +14::14::3::1424380312 +14::15::1::1424380312 +14::16::1::1424380312 +14::17::1::1424380312 +14::20::1::1424380312 +14::21::1::1424380312 +14::24::1::1424380312 +14::25::2::1424380312 +14::27::1::1424380312 +14::28::1::1424380312 +14::29::5::1424380312 +14::31::3::1424380312 +14::34::1::1424380312 +14::36::1::1424380312 +14::37::2::1424380312 +14::39::2::1424380312 +14::40::1::1424380312 +14::44::1::1424380312 +14::45::1::1424380312 +14::47::3::1424380312 +14::48::1::1424380312 +14::49::1::1424380312 +14::51::1::1424380312 +14::52::5::1424380312 +14::53::3::1424380312 +14::54::1::1424380312 +14::55::1::1424380312 +14::56::1::1424380312 +14::62::4::1424380312 +14::63::5::1424380312 +14::67::3::1424380312 +14::68::1::1424380312 +14::69::3::1424380312 +14::71::1::1424380312 +14::72::4::1424380312 +14::73::1::1424380312 +14::76::5::1424380312 +14::79::1::1424380312 +14::82::1::1424380312 +14::83::1::1424380312 +14::88::1::1424380312 +14::93::3::1424380312 +14::94::1::1424380312 +14::95::2::1424380312 +14::96::4::1424380312 +14::98::1::1424380312 +15::0::1::1424380312 +15::1::4::1424380312 +15::2::1::1424380312 +15::5::2::1424380312 +15::6::1::1424380312 +15::7::1::1424380312 +15::13::1::1424380312 +15::14::1::1424380312 +15::15::1::1424380312 +15::17::2::1424380312 +15::19::2::1424380312 +15::22::2::1424380312 +15::23::2::1424380312 +15::25::1::1424380312 +15::26::3::1424380312 +15::27::1::1424380312 +15::28::2::1424380312 +15::29::1::1424380312 +15::32::1::1424380312 +15::33::2::1424380312 +15::34::1::1424380312 +15::35::2::1424380312 +15::36::1::1424380312 +15::37::1::1424380312 +15::39::1::1424380312 +15::42::1::1424380312 +15::46::5::1424380312 +15::48::2::1424380312 +15::50::2::1424380312 +15::51::1::1424380312 +15::52::1::1424380312 +15::58::1::1424380312 +15::62::1::1424380312 +15::64::3::1424380312 +15::65::2::1424380312 +15::72::1::1424380312 +15::73::1::1424380312 +15::74::1::1424380312 +15::79::1::1424380312 +15::80::1::1424380312 +15::81::1::1424380312 +15::82::2::1424380312 +15::85::1::1424380312 +15::87::1::1424380312 +15::91::2::1424380312 +15::96::1::1424380312 +15::97::1::1424380312 +15::98::3::1424380312 +16::2::1::1424380312 +16::5::3::1424380312 +16::6::2::1424380312 +16::7::1::1424380312 +16::9::1::1424380312 +16::12::1::1424380312 +16::14::1::1424380312 +16::15::1::1424380312 +16::19::1::1424380312 +16::21::2::1424380312 +16::29::4::1424380312 +16::30::2::1424380312 +16::32::1::1424380312 +16::34::1::1424380312 +16::36::1::1424380312 +16::38::1::1424380312 +16::46::1::1424380312 +16::47::3::1424380312 +16::48::1::1424380312 +16::49::1::1424380312 +16::50::1::1424380312 +16::51::5::1424380312 +16::54::5::1424380312 +16::55::1::1424380312 +16::56::2::1424380312 +16::57::1::1424380312 +16::60::1::1424380312 +16::63::2::1424380312 +16::65::1::1424380312 +16::67::1::1424380312 +16::72::1::1424380312 +16::74::1::1424380312 +16::80::1::1424380312 +16::81::1::1424380312 +16::82::1::1424380312 +16::85::5::1424380312 +16::86::1::1424380312 +16::90::5::1424380312 +16::91::1::1424380312 +16::93::1::1424380312 +16::94::3::1424380312 +16::95::2::1424380312 +16::96::3::1424380312 +16::98::3::1424380312 +16::99::1::1424380312 +17::2::1::1424380312 +17::3::1::1424380312 +17::6::1::1424380312 +17::10::4::1424380312 +17::11::1::1424380312 +17::13::2::1424380312 +17::17::5::1424380312 +17::19::1::1424380312 +17::20::5::1424380312 +17::22::4::1424380312 +17::28::1::1424380312 +17::29::1::1424380312 +17::33::1::1424380312 +17::34::1::1424380312 +17::35::2::1424380312 +17::37::1::1424380312 +17::38::1::1424380312 +17::45::1::1424380312 +17::46::5::1424380312 +17::47::1::1424380312 +17::49::3::1424380312 +17::51::1::1424380312 +17::55::5::1424380312 +17::56::3::1424380312 +17::57::1::1424380312 +17::58::1::1424380312 +17::59::1::1424380312 +17::60::1::1424380312 +17::63::1::1424380312 +17::66::1::1424380312 +17::68::4::1424380312 +17::69::1::1424380312 +17::70::1::1424380312 +17::72::1::1424380312 +17::73::3::1424380312 +17::78::1::1424380312 +17::79::1::1424380312 +17::82::2::1424380312 +17::84::1::1424380312 +17::90::5::1424380312 +17::91::3::1424380312 +17::92::1::1424380312 +17::93::1::1424380312 +17::94::4::1424380312 +17::95::2::1424380312 +17::97::1::1424380312 +18::1::1::1424380312 +18::4::3::1424380312 +18::5::2::1424380312 +18::6::1::1424380312 +18::7::1::1424380312 +18::10::1::1424380312 +18::11::4::1424380312 +18::12::2::1424380312 +18::13::1::1424380312 +18::15::1::1424380312 +18::18::1::1424380312 +18::20::1::1424380312 +18::21::2::1424380312 +18::22::1::1424380312 +18::23::2::1424380312 +18::25::1::1424380312 +18::26::1::1424380312 +18::27::1::1424380312 +18::28::5::1424380312 +18::29::1::1424380312 +18::31::1::1424380312 +18::32::1::1424380312 +18::36::1::1424380312 +18::38::5::1424380312 +18::39::5::1424380312 +18::40::1::1424380312 +18::42::1::1424380312 +18::43::1::1424380312 +18::44::4::1424380312 +18::46::1::1424380312 +18::47::1::1424380312 +18::48::1::1424380312 +18::51::2::1424380312 +18::55::1::1424380312 +18::56::1::1424380312 +18::57::1::1424380312 +18::62::1::1424380312 +18::63::1::1424380312 +18::66::3::1424380312 +18::67::1::1424380312 +18::70::1::1424380312 +18::75::1::1424380312 +18::76::3::1424380312 +18::77::1::1424380312 +18::80::3::1424380312 +18::81::3::1424380312 +18::82::1::1424380312 +18::83::5::1424380312 +18::84::1::1424380312 +18::97::1::1424380312 +18::98::1::1424380312 +18::99::2::1424380312 +19::0::1::1424380312 +19::1::1::1424380312 +19::2::1::1424380312 +19::4::1::1424380312 +19::6::2::1424380312 +19::11::1::1424380312 +19::12::1::1424380312 +19::14::1::1424380312 +19::23::1::1424380312 +19::26::1::1424380312 +19::31::1::1424380312 +19::32::4::1424380312 +19::33::1::1424380312 +19::34::1::1424380312 +19::37::1::1424380312 +19::38::1::1424380312 +19::41::1::1424380312 +19::43::1::1424380312 +19::45::1::1424380312 +19::48::1::1424380312 +19::49::1::1424380312 +19::50::2::1424380312 +19::53::2::1424380312 +19::54::3::1424380312 +19::55::1::1424380312 +19::56::2::1424380312 +19::58::1::1424380312 +19::61::1::1424380312 +19::62::1::1424380312 +19::63::1::1424380312 +19::64::1::1424380312 +19::65::1::1424380312 +19::69::2::1424380312 +19::72::1::1424380312 +19::74::3::1424380312 +19::76::1::1424380312 +19::78::1::1424380312 +19::79::1::1424380312 +19::81::1::1424380312 +19::82::1::1424380312 +19::84::1::1424380312 +19::86::1::1424380312 +19::87::2::1424380312 +19::90::4::1424380312 +19::93::1::1424380312 +19::94::4::1424380312 +19::95::2::1424380312 +19::96::1::1424380312 +19::98::4::1424380312 +20::0::1::1424380312 +20::1::1::1424380312 +20::2::2::1424380312 +20::4::2::1424380312 +20::6::1::1424380312 +20::8::1::1424380312 +20::12::1::1424380312 +20::21::2::1424380312 +20::22::5::1424380312 +20::24::2::1424380312 +20::25::1::1424380312 +20::26::1::1424380312 +20::29::2::1424380312 +20::30::2::1424380312 +20::32::2::1424380312 +20::39::1::1424380312 +20::40::1::1424380312 +20::41::2::1424380312 +20::45::2::1424380312 +20::48::1::1424380312 +20::50::1::1424380312 +20::51::3::1424380312 +20::53::3::1424380312 +20::55::1::1424380312 +20::57::2::1424380312 +20::60::1::1424380312 +20::61::1::1424380312 +20::64::1::1424380312 +20::66::1::1424380312 +20::70::2::1424380312 +20::72::1::1424380312 +20::73::2::1424380312 +20::75::4::1424380312 +20::76::1::1424380312 +20::77::4::1424380312 +20::78::1::1424380312 +20::79::1::1424380312 +20::84::2::1424380312 +20::85::2::1424380312 +20::88::3::1424380312 +20::89::1::1424380312 +20::90::3::1424380312 +20::91::1::1424380312 +20::92::2::1424380312 +20::93::1::1424380312 +20::94::4::1424380312 +20::97::1::1424380312 +21::0::1::1424380312 +21::2::4::1424380312 +21::3::1::1424380312 +21::7::2::1424380312 +21::11::1::1424380312 +21::12::1::1424380312 +21::13::1::1424380312 +21::14::3::1424380312 +21::17::1::1424380312 +21::19::1::1424380312 +21::20::1::1424380312 +21::21::1::1424380312 +21::22::1::1424380312 +21::23::1::1424380312 +21::24::1::1424380312 +21::27::1::1424380312 +21::29::5::1424380312 +21::30::2::1424380312 +21::38::1::1424380312 +21::40::2::1424380312 +21::43::3::1424380312 +21::44::1::1424380312 +21::45::1::1424380312 +21::46::1::1424380312 +21::48::1::1424380312 +21::51::1::1424380312 +21::53::5::1424380312 +21::54::1::1424380312 +21::55::1::1424380312 +21::56::1::1424380312 +21::58::3::1424380312 +21::59::3::1424380312 +21::64::1::1424380312 +21::66::1::1424380312 +21::68::1::1424380312 +21::71::1::1424380312 +21::73::1::1424380312 +21::74::4::1424380312 +21::80::1::1424380312 +21::81::1::1424380312 +21::83::1::1424380312 +21::84::1::1424380312 +21::85::3::1424380312 +21::87::4::1424380312 +21::89::2::1424380312 +21::92::2::1424380312 +21::96::3::1424380312 +21::99::1::1424380312 +22::0::1::1424380312 +22::3::2::1424380312 +22::5::2::1424380312 +22::6::2::1424380312 +22::9::1::1424380312 +22::10::1::1424380312 +22::11::1::1424380312 +22::13::1::1424380312 +22::14::1::1424380312 +22::16::1::1424380312 +22::18::3::1424380312 +22::19::1::1424380312 +22::22::5::1424380312 +22::25::1::1424380312 +22::26::1::1424380312 +22::29::3::1424380312 +22::30::5::1424380312 +22::32::4::1424380312 +22::33::1::1424380312 +22::35::1::1424380312 +22::36::3::1424380312 +22::37::1::1424380312 +22::40::1::1424380312 +22::41::3::1424380312 +22::44::1::1424380312 +22::45::2::1424380312 +22::48::1::1424380312 +22::51::5::1424380312 +22::55::1::1424380312 +22::56::2::1424380312 +22::60::3::1424380312 +22::61::1::1424380312 +22::62::4::1424380312 +22::63::1::1424380312 +22::65::1::1424380312 +22::66::1::1424380312 +22::68::4::1424380312 +22::69::4::1424380312 +22::70::3::1424380312 +22::71::1::1424380312 +22::74::5::1424380312 +22::75::5::1424380312 +22::78::1::1424380312 +22::80::3::1424380312 +22::81::1::1424380312 +22::82::1::1424380312 +22::84::1::1424380312 +22::86::1::1424380312 +22::87::3::1424380312 +22::88::5::1424380312 +22::90::2::1424380312 +22::92::3::1424380312 +22::95::2::1424380312 +22::96::2::1424380312 +22::98::4::1424380312 +22::99::1::1424380312 +23::0::1::1424380312 +23::2::1::1424380312 +23::4::1::1424380312 +23::6::2::1424380312 +23::10::4::1424380312 +23::12::1::1424380312 +23::13::4::1424380312 +23::14::1::1424380312 +23::15::1::1424380312 +23::18::4::1424380312 +23::22::2::1424380312 +23::23::4::1424380312 +23::24::1::1424380312 +23::25::1::1424380312 +23::26::1::1424380312 +23::27::5::1424380312 +23::28::1::1424380312 +23::29::1::1424380312 +23::30::4::1424380312 +23::32::5::1424380312 +23::33::2::1424380312 +23::36::3::1424380312 +23::37::1::1424380312 +23::38::1::1424380312 +23::39::1::1424380312 +23::43::1::1424380312 +23::48::5::1424380312 +23::49::5::1424380312 +23::50::4::1424380312 +23::53::1::1424380312 +23::55::5::1424380312 +23::57::1::1424380312 +23::59::1::1424380312 +23::60::1::1424380312 +23::61::1::1424380312 +23::64::4::1424380312 +23::65::5::1424380312 +23::66::2::1424380312 +23::67::1::1424380312 +23::68::3::1424380312 +23::69::1::1424380312 +23::72::1::1424380312 +23::73::3::1424380312 +23::77::1::1424380312 +23::82::2::1424380312 +23::83::1::1424380312 +23::84::1::1424380312 +23::85::1::1424380312 +23::87::3::1424380312 +23::88::1::1424380312 +23::95::2::1424380312 +23::97::1::1424380312 +24::4::1::1424380312 +24::6::3::1424380312 +24::7::1::1424380312 +24::10::2::1424380312 +24::12::1::1424380312 +24::15::1::1424380312 +24::19::1::1424380312 +24::24::1::1424380312 +24::27::3::1424380312 +24::30::5::1424380312 +24::31::1::1424380312 +24::32::3::1424380312 +24::33::1::1424380312 +24::37::1::1424380312 +24::39::1::1424380312 +24::40::1::1424380312 +24::42::1::1424380312 +24::43::3::1424380312 +24::45::2::1424380312 +24::46::1::1424380312 +24::47::1::1424380312 +24::48::1::1424380312 +24::49::1::1424380312 +24::50::1::1424380312 +24::52::5::1424380312 +24::57::1::1424380312 +24::59::4::1424380312 +24::63::4::1424380312 +24::65::1::1424380312 +24::66::1::1424380312 +24::67::1::1424380312 +24::68::3::1424380312 +24::69::5::1424380312 +24::71::1::1424380312 +24::72::4::1424380312 +24::77::4::1424380312 +24::78::1::1424380312 +24::80::1::1424380312 +24::82::1::1424380312 +24::84::1::1424380312 +24::86::1::1424380312 +24::87::1::1424380312 +24::88::2::1424380312 +24::89::1::1424380312 +24::90::5::1424380312 +24::91::1::1424380312 +24::92::1::1424380312 +24::94::2::1424380312 +24::95::1::1424380312 +24::96::5::1424380312 +24::98::1::1424380312 +24::99::1::1424380312 +25::1::3::1424380312 +25::2::1::1424380312 +25::7::1::1424380312 +25::9::1::1424380312 +25::12::3::1424380312 +25::16::3::1424380312 +25::17::1::1424380312 +25::18::1::1424380312 +25::20::1::1424380312 +25::22::1::1424380312 +25::23::1::1424380312 +25::26::2::1424380312 +25::29::1::1424380312 +25::30::1::1424380312 +25::31::2::1424380312 +25::33::4::1424380312 +25::34::3::1424380312 +25::35::2::1424380312 +25::36::1::1424380312 +25::37::1::1424380312 +25::40::1::1424380312 +25::41::1::1424380312 +25::43::1::1424380312 +25::47::4::1424380312 +25::50::1::1424380312 +25::51::1::1424380312 +25::53::1::1424380312 +25::56::1::1424380312 +25::58::2::1424380312 +25::64::2::1424380312 +25::67::2::1424380312 +25::68::1::1424380312 +25::70::1::1424380312 +25::71::4::1424380312 +25::73::1::1424380312 +25::74::1::1424380312 +25::76::1::1424380312 +25::79::1::1424380312 +25::82::1::1424380312 +25::84::2::1424380312 +25::85::1::1424380312 +25::91::3::1424380312 +25::92::1::1424380312 +25::94::1::1424380312 +25::95::1::1424380312 +25::97::2::1424380312 +26::0::1::1424380312 +26::1::1::1424380312 +26::2::1::1424380312 +26::3::1::1424380312 +26::4::4::1424380312 +26::5::2::1424380312 +26::6::3::1424380312 +26::7::5::1424380312 +26::13::3::1424380312 +26::14::1::1424380312 +26::16::1::1424380312 +26::18::3::1424380312 +26::20::1::1424380312 +26::21::3::1424380312 +26::22::5::1424380312 +26::23::5::1424380312 +26::24::5::1424380312 +26::27::1::1424380312 +26::31::1::1424380312 +26::35::1::1424380312 +26::36::4::1424380312 +26::40::1::1424380312 +26::44::1::1424380312 +26::45::2::1424380312 +26::47::1::1424380312 +26::48::1::1424380312 +26::49::3::1424380312 +26::50::2::1424380312 +26::52::1::1424380312 +26::54::4::1424380312 +26::55::1::1424380312 +26::57::3::1424380312 +26::58::1::1424380312 +26::61::1::1424380312 +26::62::2::1424380312 +26::66::1::1424380312 +26::68::4::1424380312 +26::71::1::1424380312 +26::73::4::1424380312 +26::76::1::1424380312 +26::81::3::1424380312 +26::85::1::1424380312 +26::86::3::1424380312 +26::88::5::1424380312 +26::91::1::1424380312 +26::94::5::1424380312 +26::95::1::1424380312 +26::96::1::1424380312 +26::97::1::1424380312 +27::0::1::1424380312 +27::9::1::1424380312 +27::10::1::1424380312 +27::18::4::1424380312 +27::19::3::1424380312 +27::20::1::1424380312 +27::22::2::1424380312 +27::24::2::1424380312 +27::25::1::1424380312 +27::27::3::1424380312 +27::28::1::1424380312 +27::29::1::1424380312 +27::31::1::1424380312 +27::33::3::1424380312 +27::40::1::1424380312 +27::42::1::1424380312 +27::43::1::1424380312 +27::44::3::1424380312 +27::45::1::1424380312 +27::51::3::1424380312 +27::52::1::1424380312 +27::55::3::1424380312 +27::57::1::1424380312 +27::59::1::1424380312 +27::60::1::1424380312 +27::61::1::1424380312 +27::64::1::1424380312 +27::66::3::1424380312 +27::68::1::1424380312 +27::70::1::1424380312 +27::71::2::1424380312 +27::72::1::1424380312 +27::75::3::1424380312 +27::78::1::1424380312 +27::80::3::1424380312 +27::82::1::1424380312 +27::83::3::1424380312 +27::86::1::1424380312 +27::87::2::1424380312 +27::90::1::1424380312 +27::91::1::1424380312 +27::92::1::1424380312 +27::93::1::1424380312 +27::94::2::1424380312 +27::95::1::1424380312 +27::98::1::1424380312 +28::0::3::1424380312 +28::1::1::1424380312 +28::2::4::1424380312 +28::3::1::1424380312 +28::6::1::1424380312 +28::7::1::1424380312 +28::12::5::1424380312 +28::13::2::1424380312 +28::14::1::1424380312 +28::15::1::1424380312 +28::17::1::1424380312 +28::19::3::1424380312 +28::20::1::1424380312 +28::23::3::1424380312 +28::24::3::1424380312 +28::27::1::1424380312 +28::29::1::1424380312 +28::33::1::1424380312 +28::34::1::1424380312 +28::36::1::1424380312 +28::38::2::1424380312 +28::39::2::1424380312 +28::44::1::1424380312 +28::45::1::1424380312 +28::49::4::1424380312 +28::50::1::1424380312 +28::52::1::1424380312 +28::54::1::1424380312 +28::56::1::1424380312 +28::57::3::1424380312 +28::58::1::1424380312 +28::59::1::1424380312 +28::60::1::1424380312 +28::62::3::1424380312 +28::63::1::1424380312 +28::65::1::1424380312 +28::75::1::1424380312 +28::78::1::1424380312 +28::81::5::1424380312 +28::82::4::1424380312 +28::83::1::1424380312 +28::85::1::1424380312 +28::88::2::1424380312 +28::89::4::1424380312 +28::90::1::1424380312 +28::92::5::1424380312 +28::94::1::1424380312 +28::95::2::1424380312 +28::98::1::1424380312 +28::99::1::1424380312 +29::3::1::1424380312 +29::4::1::1424380312 +29::5::1::1424380312 +29::7::2::1424380312 +29::9::1::1424380312 +29::10::3::1424380312 +29::11::1::1424380312 +29::13::3::1424380312 +29::14::1::1424380312 +29::15::1::1424380312 +29::17::3::1424380312 +29::19::3::1424380312 +29::22::3::1424380312 +29::23::4::1424380312 +29::25::1::1424380312 +29::29::1::1424380312 +29::31::1::1424380312 +29::32::4::1424380312 +29::33::2::1424380312 +29::36::2::1424380312 +29::38::3::1424380312 +29::39::1::1424380312 +29::42::1::1424380312 +29::46::5::1424380312 +29::49::3::1424380312 +29::51::2::1424380312 +29::59::1::1424380312 +29::61::1::1424380312 +29::62::1::1424380312 +29::67::1::1424380312 +29::68::3::1424380312 +29::69::1::1424380312 +29::70::1::1424380312 +29::74::1::1424380312 +29::75::1::1424380312 +29::79::2::1424380312 +29::80::1::1424380312 +29::81::2::1424380312 +29::83::1::1424380312 +29::85::1::1424380312 +29::86::1::1424380312 +29::90::4::1424380312 +29::93::1::1424380312 +29::94::4::1424380312 +29::97::1::1424380312 +29::99::1::1424380312 diff --git a/data/mllib/gmm_data.txt b/data/mllib/gmm_data.txt new file mode 100644 index 0000000000000..934ee4a83a2df --- /dev/null +++ b/data/mllib/gmm_data.txt @@ -0,0 +1,2000 @@ + 2.59470454e+00 2.12298217e+00 + 1.15807024e+00 -1.46498723e-01 + 2.46206638e+00 6.19556894e-01 + -5.54845070e-01 -7.24700066e-01 + -3.23111426e+00 -1.42579084e+00 + 3.02978115e+00 7.87121753e-01 + 1.97365907e+00 1.15914704e+00 + -6.44852101e+00 -3.18154314e+00 + 1.30963349e+00 1.62866434e-01 + 4.26482541e+00 2.15547996e+00 + 3.79927257e+00 1.50572445e+00 + 4.17452609e-01 -6.74032760e-01 + 4.21117627e-01 4.45590255e-01 + -2.80425571e+00 -7.77150554e-01 + 2.55928797e+00 7.03954218e-01 + 1.32554059e+00 -9.46663152e-01 + -3.39691439e+00 -1.49005743e+00 + -2.26542270e-01 3.60052515e-02 + 1.04994198e+00 5.29825685e-01 + -1.51566882e+00 -1.86264432e-01 + -3.27928172e-01 -7.60859110e-01 + -3.18054866e-01 3.97719805e-01 + 1.65579418e-01 -3.47232033e-01 + 6.47162333e-01 4.96059961e-02 + -2.80776647e-01 4.79418757e-01 + 7.45069752e-01 1.20790281e-01 + 2.13604102e-01 1.59542555e-01 + -3.08860224e+00 -1.43259870e+00 + 8.97066497e-01 1.10206801e+00 + -2.23918874e-01 -1.07267267e+00 + 2.51525708e+00 2.84761973e-01 + 9.98052532e-01 1.08333783e+00 + 1.76705588e+00 8.18866778e-01 + 5.31555163e-02 -1.90111151e-01 + -2.17405059e+00 7.21854582e-02 + -2.13772505e+00 -3.62010387e-01 + 2.95974057e+00 1.31602381e+00 + 2.74053561e+00 1.61781757e+00 + 6.68135448e-01 2.86586009e-01 + 2.82323739e+00 1.74437257e+00 + 8.11540288e-01 5.50744478e-01 + 4.10050897e-01 5.10668402e-03 + 9.58626136e-01 -3.49633680e-01 + 4.66599798e+00 1.49964894e+00 + 4.94507794e-01 2.58928077e-01 + -2.36029742e+00 -1.61042909e+00 + -4.99306804e-01 -8.04984769e-01 + 1.07448510e+00 9.39605828e-01 + -1.80448949e+00 -1.05983264e+00 + -3.22353821e-01 1.73612093e-01 + 1.85418702e+00 1.15640643e+00 + 6.93794163e-01 6.59993560e-01 + 1.99399102e+00 1.44547123e+00 + 3.38866124e+00 1.23379290e+00 + -4.24067720e+00 -1.22264282e+00 + 6.03230201e-02 2.95232729e-01 + -3.59341813e+00 -7.17453726e-01 + 4.87447372e-01 -2.00733911e-01 + 1.20149195e+00 4.07880197e-01 + -2.13331464e+00 -4.58518077e-01 + -3.84091083e+00 -1.71553950e+00 + -5.37279250e-01 2.64822629e-02 + -2.10155227e+00 -1.32558103e+00 + -1.71318897e+00 -7.12098563e-01 + -1.46280695e+00 -1.84868337e-01 + -3.59785325e+00 -1.54832434e+00 + -5.77528081e-01 -5.78580857e-01 + 3.14734283e-01 5.80184639e-01 + -2.71164714e+00 -1.19379432e+00 + 1.09634489e+00 7.20143887e-01 + -3.05527722e+00 -1.47774064e+00 + 6.71753586e-01 7.61350020e-01 + 3.98294144e+00 1.54166484e+00 + -3.37220384e+00 -2.21332064e+00 + 1.81222914e+00 7.41212752e-01 + 2.71458282e-01 1.36329078e-01 + -3.97815359e-01 1.16766886e-01 + -1.70192814e+00 -9.75851571e-01 + -3.46803804e+00 -1.09965988e+00 + -1.69649627e+00 -5.76045801e-01 + -1.02485636e-01 -8.81841246e-01 + -3.24194667e-02 2.55429276e-01 + -2.75343168e+00 -1.51366320e+00 + -2.78676702e+00 -5.22360489e-01 + 1.70483164e+00 1.19769805e+00 + 4.92022579e-01 3.24944706e-01 + 2.48768464e+00 1.00055363e+00 + 4.48786400e-01 7.63902870e-01 + 2.93862696e+00 1.73809968e+00 + -3.55019305e+00 -1.97875558e+00 + 1.74270784e+00 6.90229224e-01 + 5.13391994e-01 4.58374016e-01 + 1.78379499e+00 9.08026381e-01 + 1.75814147e+00 7.41449784e-01 + -2.30687792e-01 3.91009729e-01 + 3.92271353e+00 1.44006290e+00 + 2.93361679e-01 -4.99886375e-03 + 2.47902690e-01 -7.49542503e-01 + -3.97675355e-01 1.36824887e-01 + 3.56535953e+00 1.15181329e+00 + 3.22425301e+00 1.28702383e+00 + -2.94192478e-01 -2.42382557e-01 + 8.02068864e-01 -1.51671475e-01 + 8.54133530e-01 -4.89514885e-02 + -1.64316316e-01 -5.34642346e-01 + -6.08485405e-01 -2.10332352e-01 + -2.18940059e+00 -1.07024952e+00 + -1.71586960e+00 -2.83333492e-02 + 1.70200448e-01 -3.28031178e-01 + -1.97210346e+00 -5.39948532e-01 + 2.19500160e+00 1.05697170e+00 + -1.76239935e+00 -1.09377438e+00 + 1.68314744e+00 6.86491164e-01 + -2.99852288e+00 -1.46619067e+00 + -2.23769560e+00 -9.15008355e-01 + 9.46887516e-01 5.58410503e-01 + 5.02153123e-01 1.63851235e-01 + -9.70297062e-01 3.14625374e-01 + -1.29405593e+00 -8.20994131e-01 + 2.72516079e+00 7.85839947e-01 + 1.45788024e+00 3.37487353e-01 + -4.36292749e-01 -5.42150480e-01 + 2.21304711e+00 1.25254042e+00 + -1.20810271e-01 4.79632898e-01 + -3.30884511e+00 -1.50607586e+00 + -6.55882455e+00 -1.94231256e+00 + -3.17033630e+00 -9.94678930e-01 + 1.42043617e+00 7.28808957e-01 + -1.57546099e+00 -1.10320497e+00 + -3.22748754e+00 -1.64174579e+00 + 2.96776017e-03 -3.16191512e-02 + -2.25986054e+00 -6.13123197e-01 + 2.49434243e+00 7.73069183e-01 + 9.08494049e-01 -1.53926853e-01 + -2.80559090e+00 -1.37474221e+00 + 4.75224286e-01 2.53153674e-01 + 4.37644006e+00 8.49116998e-01 + 2.27282959e+00 6.16568202e-01 + 1.16006880e+00 1.65832798e-01 + -1.67163193e+00 -1.22555386e+00 + -1.38231118e+00 -7.29575504e-01 + -3.49922750e+00 -2.26446675e+00 + -3.73780110e-01 -1.90657869e-01 + 1.68627679e+00 1.05662987e+00 + -3.28891792e+00 -1.11080334e+00 + -2.59815798e+00 -1.51410198e+00 + -2.61203309e+00 -6.00143552e-01 + 6.58964943e-01 4.47216094e-01 + -2.26711381e+00 -7.26512923e-01 + -5.31429009e-02 -1.97925341e-02 + 3.19749807e+00 9.20425476e-01 + -1.37595787e+00 -6.58062732e-01 + 8.09900278e-01 -3.84286160e-01 + -5.07741280e+00 -1.97683808e+00 + -2.99764250e+00 -1.50753777e+00 + -9.87671815e-01 -4.63255889e-01 + 1.65390765e+00 6.73806615e-02 + 5.51252659e+00 2.69842267e+00 + -2.23724309e+00 -4.77624004e-01 + 4.99726228e+00 1.74690949e+00 + 1.75859162e-01 -1.49350995e-01 + 4.13382789e+00 1.31735161e+00 + 2.69058117e+00 4.87656923e-01 + 1.07180318e+00 1.01426954e+00 + 3.37216869e+00 1.05955377e+00 + -2.95006781e+00 -1.57048303e+00 + -2.46401648e+00 -8.37056374e-01 + 1.19012962e-01 7.54702770e-01 + 3.34142539e+00 4.81938295e-01 + 2.92643913e+00 1.04301050e+00 + 2.89697751e+00 1.37551442e+00 + -1.03094242e+00 2.20903962e-01 + -5.13914589e+00 -2.23355387e+00 + -8.81680780e-01 1.83590000e-01 + 2.82334775e+00 1.26650464e+00 + -2.81042540e-01 -3.26370240e-01 + 2.97995487e+00 8.34569452e-01 + -1.39857135e+00 -1.15798385e+00 + 4.27186506e+00 9.04253702e-01 + 6.98684517e-01 7.91167305e-01 + 3.52233095e+00 1.29976473e+00 + 2.21448029e+00 2.73213379e-01 + -3.13505683e-01 -1.20593774e-01 + 3.70571571e+00 1.06220876e+00 + 9.83881041e-01 5.67713803e-01 + -2.17897705e+00 2.52925205e-01 + 1.38734039e+00 4.61287066e-01 + -1.41181602e+00 -1.67248955e-02 + -1.69974639e+00 -7.17812071e-01 + -2.01005793e-01 -7.49662056e-01 + 1.69016336e+00 3.24687979e-01 + -2.03250179e+00 -2.76108460e-01 + 3.68776848e-01 4.12536941e-01 + 7.66238259e-01 -1.84750637e-01 + -2.73989147e-01 -1.72817250e-01 + -2.18623745e+00 -2.10906798e-01 + -1.39795625e-01 3.26066094e-02 + -2.73826912e-01 -6.67586097e-02 + -1.57880654e+00 -4.99395900e-01 + 4.55950908e+00 2.29410489e+00 + -7.36479631e-01 -1.57861857e-01 + 1.92082888e+00 1.05843391e+00 + 4.29192810e+00 1.38127810e+00 + 1.61852879e+00 1.95871986e-01 + -1.95027403e+00 -5.22448168e-01 + -1.67446281e+00 -9.41497162e-01 + 6.07097859e-01 3.44178029e-01 + -3.44004683e+00 -1.49258461e+00 + 2.72114752e+00 6.00728991e-01 + 8.80685522e-01 -2.53243336e-01 + 1.39254928e+00 3.42988512e-01 + 1.14194836e-01 -8.57945694e-02 + -1.49387332e+00 -7.60860481e-01 + -1.98053285e+00 -4.86039865e-01 + 3.56008568e+00 1.08438692e+00 + 2.27833961e-01 1.09441881e+00 + -1.16716710e+00 -6.54778242e-01 + 2.02156613e+00 5.42075758e-01 + 1.08429178e+00 -7.67420693e-01 + 6.63058455e-01 4.61680991e-01 + -1.06201537e+00 1.38862846e-01 + 3.08701875e+00 8.32580273e-01 + -4.96558108e-01 -2.47031257e-01 + 7.95109987e-01 7.59314147e-02 + -3.39903524e-01 8.71565566e-03 + 8.68351357e-01 4.78358641e-01 + 1.48750819e+00 7.63257420e-01 + -4.51224101e-01 -4.44056898e-01 + -3.02734750e-01 -2.98487961e-01 + 5.46846609e-01 7.02377629e-01 + 1.65129778e+00 3.74008231e-01 + -7.43336512e-01 3.95723531e-01 + -5.88446605e-01 -6.47520211e-01 + 3.58613167e+00 1.95024937e+00 + 3.11718883e+00 8.37984715e-01 + 1.80919244e+00 9.62644986e-01 + 5.43856371e-02 -5.86297543e-01 + -1.95186766e+00 -1.02624212e-01 + 8.95628057e-01 5.91812281e-01 + 4.97691627e-02 5.31137156e-01 + -1.07633113e+00 -2.47392788e-01 + -1.17257986e+00 -8.68528265e-01 + -8.19227665e-02 5.80579434e-03 + -2.86409787e-01 1.95812924e-01 + 1.10582671e+00 7.42853240e-01 + 4.06429774e+00 1.06557476e+00 + -3.42521792e+00 -7.74327139e-01 + 1.28468671e+00 6.20431661e-01 + 6.01201008e-01 -1.16799728e-01 + -1.85058727e-01 -3.76235293e-01 + 5.44083324e+00 2.98490868e+00 + 2.69273070e+00 7.83901153e-01 + 1.88938036e-01 -4.83222152e-01 + 1.05667256e+00 -2.57003165e-01 + 2.99711662e-01 -4.33131912e-01 + 7.73689216e-02 -1.78738364e-01 + 9.58326279e-01 6.38325706e-01 + -3.97727049e-01 2.27314759e-01 + 3.36098175e+00 1.12165237e+00 + 1.77804871e+00 6.46961933e-01 + -2.86945546e+00 -1.00395518e+00 + 3.03494815e+00 7.51814612e-01 + -1.43658194e+00 -3.55432244e-01 + -3.08455105e+00 -1.51535106e+00 + -1.55841975e+00 3.93454820e-02 + 7.96073412e-01 -3.11036969e-01 + -9.84125401e-01 -1.02064649e+00 + -7.75688143e+00 -3.65219926e+00 + 1.53816429e+00 7.65926670e-01 + -4.92712738e-01 2.32244240e-02 + -1.93166919e+00 -1.07701304e+00 + 2.03029875e-02 -7.54055699e-01 + 2.52177489e+00 1.01544979e+00 + 3.65109048e-01 -9.48328494e-01 + -1.28849143e-01 2.51947174e-01 + -1.02428075e+00 -9.37767116e-01 + -3.04179748e+00 -9.97926994e-01 + -2.51986980e+00 -1.69117413e+00 + -1.24900838e+00 -4.16179917e-01 + 2.77943992e+00 1.22842327e+00 + -4.37434557e+00 -1.70182693e+00 + -1.60019319e+00 -4.18345639e-01 + -1.67613646e+00 -9.44087262e-01 + -9.00843245e-01 8.26378089e-02 + 3.29770621e-01 -9.07870444e-01 + -2.84650535e+00 -9.00155396e-01 + 1.57111705e+00 7.07432268e-01 + 1.24948552e+00 1.04812849e-01 + 1.81440558e+00 9.53545082e-01 + -1.74915794e+00 -1.04606288e+00 + 1.20593269e+00 -1.12607147e-02 + 1.36004919e-01 -1.09828044e+00 + 2.57480693e-01 3.34941541e-01 + 7.78775385e-01 -5.32494732e-01 + -1.79155126e+00 -6.29994129e-01 + -1.75706839e+00 -8.35100126e-01 + 4.29512012e-01 7.81426910e-02 + 3.08349370e-01 -1.27359861e-01 + 1.05560329e+00 4.55150640e-01 + 1.95662574e+00 1.17593217e+00 + 8.77376632e-01 6.57866662e-01 + 7.71311255e-01 9.15134334e-02 + -6.36978275e+00 -2.55874241e+00 + -2.98335339e+00 -1.59567024e+00 + -3.67104587e-01 1.85315291e-01 + 1.95347407e+00 -7.15503113e-02 + 8.45556363e-01 6.51256415e-02 + 9.42868521e-01 3.56647624e-01 + 2.99321875e+00 1.07505254e+00 + -2.91030538e-01 -3.77637183e-01 + 1.62870918e+00 3.37563671e-01 + 2.05773173e-01 3.43337416e-01 + -8.40879199e-01 -1.35600767e-01 + 1.38101624e+00 5.99253495e-01 + -6.93715607e+00 -2.63580662e+00 + -1.04423404e+00 -8.32865050e-01 + 1.33448476e+00 1.04863475e+00 + 6.01675207e-01 1.98585194e-01 + 2.31233993e+00 7.98628331e-01 + 1.85201313e-01 -1.76070247e+00 + 1.92006354e+00 8.45737582e-01 + 1.06320415e+00 2.93426068e-01 + -1.20360141e+00 -1.00301288e+00 + 1.95926629e+00 6.26643532e-01 + 6.04483978e-02 5.72643059e-01 + -1.04568563e+00 -5.91021496e-01 + 2.62300678e+00 9.50997831e-01 + -4.04610275e-01 3.73150879e-01 + 2.26371902e+00 8.73627529e-01 + 2.12545313e+00 7.90640352e-01 + 7.72181917e-03 1.65718952e-02 + 1.00422340e-01 -2.05562936e-01 + -1.22989802e+00 -1.01841681e-01 + 3.09064082e+00 1.04288010e+00 + 5.18274167e+00 1.34749259e+00 + -8.32075153e-01 -1.97592029e-01 + 3.84126764e-02 5.58171345e-01 + 4.99560727e-01 -4.26154438e-02 + 4.79071151e+00 2.19728942e+00 + -2.78437968e+00 -1.17812590e+00 + -2.22804226e+00 -4.31174255e-01 + 8.50762292e-01 -1.06445261e-01 + 1.10812830e+00 -2.59118812e-01 + -2.91450155e-01 6.42802679e-01 + -1.38631532e-01 -5.88585623e-01 + -5.04120983e-01 -2.17094915e-01 + 3.41410820e+00 1.67897767e+00 + -2.23697326e+00 -6.62735244e-01 + -3.55961064e-01 -1.27647226e-01 + -3.55568274e+00 -2.49011369e+00 + -8.77586408e-01 -9.38268065e-03 + 1.52382384e-01 -5.62155760e-01 + 1.55885574e-01 1.07617069e-01 + -8.37129973e-01 -5.22259081e-01 + -2.92741750e+00 -1.35049428e+00 + -3.54670781e-01 5.69205952e-02 + 2.21030255e+00 1.34689986e+00 + 1.60787722e+00 5.75984706e-01 + 1.32294221e+00 5.31577509e-01 + 7.05672928e-01 3.34241244e-01 + 1.41406179e+00 1.15783408e+00 + -6.92172228e-01 -2.84817896e-01 + 3.28358655e-01 -2.66910083e-01 + 1.68013644e-01 -4.28016549e-02 + 2.07365974e+00 7.76496211e-01 + -3.92974907e-01 2.46796730e-01 + -5.76078636e-01 3.25676963e-01 + -1.82547204e-01 -5.06410543e-01 + 3.04754906e+00 1.16174496e+00 + -3.01090632e+00 -1.09195183e+00 + -1.44659696e+00 -6.87838682e-01 + 2.11395861e+00 9.10495785e-01 + 1.40962871e+00 1.13568678e+00 + -1.66653234e-01 -2.10012503e-01 + 3.17456029e+00 9.74502922e-01 + 2.15944820e+00 8.62807189e-01 + -3.45418719e+00 -1.33647548e+00 + -3.41357732e+00 -8.47048920e-01 + -3.06702448e-01 -6.64280634e-01 + -2.86930714e-01 -1.35268264e-01 + -3.15835557e+00 -5.43439253e-01 + 2.49541440e-01 -4.71733570e-01 + 2.71933912e+00 4.13308399e-01 + -2.43787038e+00 -1.08050547e+00 + -4.90234490e-01 -6.64069865e-01 + 8.99524451e-02 5.76180541e-01 + 5.00500404e+00 2.12125521e+00 + -1.73107940e-01 -2.28506575e-02 + 5.44938858e-01 -1.29523352e-01 + 5.13526842e+00 1.68785993e+00 + 1.70228304e+00 1.02601138e+00 + 3.58957507e+00 1.54396196e+00 + 1.85615738e+00 4.92916197e-01 + 2.55772147e+00 7.88438908e-01 + -1.57008279e+00 -4.17377300e-01 + -1.42548604e+00 -3.63684860e-01 + -8.52026118e-01 2.72052686e-01 + -5.10563077e+00 -2.35665994e+00 + -2.95517031e+00 -1.84945297e+00 + -2.91947959e+00 -1.66016784e+00 + -4.21462387e+00 -1.41131535e+00 + 6.59901121e-01 4.87156314e-01 + -9.75352532e-01 -4.50231285e-01 + -5.94084444e-01 -1.16922670e+00 + 7.50554615e-01 -9.83692552e-01 + 1.07054926e+00 2.77143030e-01 + -3.88079578e-01 -4.17737309e-02 + -9.59373733e-01 -8.85454886e-01 + -7.53560665e-02 -5.16223870e-02 + 9.84108158e-01 -5.89290700e-02 + 1.87272961e-01 -4.34238391e-01 + 6.86509981e-01 -3.15116460e-01 + -1.07762538e+00 6.58984161e-02 + 6.09266592e-01 6.91808473e-02 + -8.30529954e-01 -7.00454791e-01 + -9.13179464e-01 -6.31712891e-01 + 7.68744851e-01 1.09840676e+00 + -1.07606690e+00 -8.78390282e-01 + -1.71038184e+00 -5.73606033e-01 + 8.75982765e-01 3.66343143e-01 + -7.04919009e-01 -8.49182590e-01 + -1.00274668e+00 -7.99573611e-01 + -1.05562848e+00 -5.84060076e-01 + 4.03490015e+00 1.28679206e+00 + -3.53484804e+00 -1.71381255e+00 + 2.31527363e-01 1.04179397e-01 + -3.58592392e-02 3.74895739e-01 + 3.92253428e+00 1.81852726e+00 + -7.27384249e-01 -6.45605128e-01 + 4.65678097e+00 2.41379899e+00 + 1.16750534e+00 7.60718205e-01 + 1.15677059e+00 7.96225550e-01 + -1.42920261e+00 -4.66946295e-01 + 3.71148192e+00 1.88060191e+00 + 2.44052407e+00 3.84472199e-01 + -1.64535035e+00 -8.94530036e-01 + -3.69608753e+00 -1.36402754e+00 + 2.24419208e+00 9.69744889e-01 + 2.54822427e+00 1.22613039e+00 + 3.77484909e-01 -5.98521878e-01 + -3.61521175e+00 -1.11123912e+00 + 3.28113127e+00 1.52551775e+00 + -3.51030902e+00 -1.53913980e+00 + -2.44874505e+00 -6.30246005e-01 + -3.42516153e-01 -5.07352665e-01 + 1.09110502e+00 6.36821628e-01 + -2.49434967e+00 -8.02827146e-01 + 1.41763139e+00 -3.46591820e-01 + 1.61108619e+00 5.93871102e-01 + 3.97371717e+00 1.35552499e+00 + -1.33437177e+00 -2.83908670e-01 + -1.41606483e+00 -1.76402601e-01 + 2.23945322e-01 -1.77157065e-01 + 2.60271569e+00 2.40778251e-01 + -2.82213895e-02 1.98255474e-01 + 4.20727940e+00 1.31490863e+00 + 3.36944889e+00 1.57566635e+00 + 3.53049396e+00 1.73579350e+00 + -1.29170202e+00 -1.64196290e+00 + 9.27295604e-01 9.98808036e-01 + 1.75321843e-01 -2.83267817e-01 + -2.19069578e+00 -1.12814358e+00 + 1.66606031e+00 7.68006933e-01 + -7.13826035e-01 5.20881684e-02 + -3.43821888e+00 -2.36137021e+00 + -5.93210310e-01 1.21843813e-01 + -4.09800822e+00 -1.39893953e+00 + 2.74110954e+00 1.52728606e+00 + 1.72652512e+00 -1.25435113e-01 + 1.97722357e+00 6.40667481e-01 + 4.18635780e-01 3.57018509e-01 + -1.78303569e+00 -2.11864764e-01 + -3.52809366e+00 -2.58794450e-01 + -4.72407090e+00 -1.63870734e+00 + 1.73917807e+00 8.73251829e-01 + 4.37979356e-01 8.49210569e-01 + 3.93791881e+00 1.76269490e+00 + 2.79065411e+00 1.04019042e+00 + -8.47426142e-01 -3.40136892e-01 + -4.24389181e+00 -1.80253120e+00 + -1.86675870e+00 -7.64558265e-01 + 9.46212675e-01 -7.77681445e-02 + -2.82448462e+00 -1.33592449e+00 + -2.57938567e+00 -1.56554690e+00 + -2.71615767e+00 -6.27667233e-01 + -1.55999166e+00 -5.81013466e-01 + -4.24696864e-01 -7.44673250e-01 + 1.67592970e+00 7.68164292e-01 + 8.48455216e-01 -6.05681126e-01 + 6.12575454e+00 1.65607584e+00 + 1.38207327e+00 2.39261863e-01 + 3.13364450e+00 1.17154698e+00 + 1.71694858e+00 1.26744905e+00 + -1.61746367e+00 -8.80098073e-01 + -8.52196756e-01 -9.27299728e-01 + -1.51562462e-01 -8.36552490e-02 + -7.04792753e-01 -1.24726713e-02 + -3.35265757e+00 -1.82176312e+00 + 3.32173170e-01 -1.33405580e-01 + 4.95841013e-01 4.58292712e-01 + 1.57713955e+00 7.79272991e-01 + 2.09743109e+00 9.23542557e-01 + 3.90450311e-03 -8.42873164e-01 + 2.59519038e+00 7.56479591e-01 + -5.77643976e-01 -2.36401904e-01 + -5.22310654e-01 1.34187830e-01 + -2.22096086e+00 -7.75507719e-01 + 1.35907831e+00 7.80197510e-01 + 3.80355868e+00 1.16983476e+00 + 3.82746596e+00 1.31417718e+00 + 3.30451183e+00 1.55398159e+00 + -3.42917814e-01 -8.62281222e-02 + -2.59093020e+00 -9.29883526e-01 + 1.40928562e+00 1.08398346e+00 + 1.54400137e-01 3.35881092e-01 + 1.59171586e+00 1.18855802e+00 + -5.25164002e-01 -1.03104220e-01 + 2.20067959e+00 1.37074713e+00 + 6.97860830e-01 6.27718548e-01 + -4.59743507e-01 1.36061163e-01 + -1.04691963e-01 -2.16271727e-01 + -1.08905573e+00 -5.95510769e-01 + -1.00826983e+00 -5.38509162e-02 + -3.16402719e+00 -1.33414216e+00 + 1.47870874e-01 1.75234619e-01 + -2.57078234e-01 7.03316889e-02 + 1.81073945e+00 4.26901462e-01 + 2.65476530e+00 6.74217273e-01 + 1.27539811e+00 6.22914081e-01 + -3.76750499e-01 -1.20629449e+00 + 1.00177595e+00 -1.40660091e-01 + -2.98919265e+00 -1.65145013e+00 + -2.21557682e+00 -8.11123452e-01 + -3.22635378e+00 -1.65639056e+00 + -2.72868553e+00 -1.02812087e+00 + 1.26042797e+00 8.49005248e-01 + -9.38318534e-01 -9.87588651e-01 + 3.38013194e-01 -1.00237461e-01 + 1.91175691e+00 8.48716369e-01 + 4.30244344e-01 6.05539915e-02 + 2.21783435e+00 3.03268204e-01 + 1.78019576e+00 1.27377108e+00 + 1.59733274e+00 4.40674687e-02 + 3.97428484e+00 2.20881566e+00 + -2.41108677e+00 -6.01410418e-01 + -2.50796499e+00 -5.71169866e-01 + -3.71957427e+00 -1.38195726e+00 + -1.57992670e+00 1.32068593e-01 + -1.35278851e+00 -6.39349270e-01 + 1.23075932e+00 2.40445409e-01 + 1.35606530e+00 4.33180078e-01 + 9.60968518e-02 2.26734255e-01 + 6.22975063e-01 5.03431915e-02 + -1.47624851e+00 -3.60568238e-01 + -2.49337808e+00 -1.15083052e+00 + 2.15717792e+00 1.03071559e+00 + -3.07814376e-02 1.38700314e-02 + 4.52049499e-02 -4.86409775e-01 + 2.58231061e+00 1.14327809e-01 + 1.10999138e+00 -5.18568405e-01 + -2.19426443e-01 -5.37505538e-01 + -4.44740298e-01 6.78099955e-01 + 4.03379080e+00 1.49825720e+00 + -5.13182408e-01 -4.90201950e-01 + -6.90139716e-01 1.63875126e-01 + -8.17281461e-01 2.32155064e-01 + -2.92357619e-01 -8.02573544e-01 + -1.80769841e+00 -7.58907326e-01 + 2.16981590e+00 1.06728873e+00 + 1.98995203e-01 -6.84176682e-02 + -2.39546753e+00 -2.92873789e-01 + -4.24251021e+00 -1.46255564e+00 + -5.01411291e-01 -5.95712813e-03 + 2.68085809e+00 1.42883780e+00 + -4.13289873e+00 -1.62729388e+00 + 1.87957843e+00 3.63341638e-01 + -1.15270744e+00 -3.03563774e-01 + -4.43994248e+00 -2.97323905e+00 + -7.17067733e-01 -7.08349542e-01 + -3.28870393e+00 -1.19263863e+00 + -7.55325944e-01 -5.12703329e-01 + -2.07291938e+00 -2.65025085e-01 + -7.50073814e-01 -1.70771041e-01 + -8.77381404e-01 -5.47417325e-01 + -5.33725862e-01 5.15837119e-01 + 8.45056431e-01 2.82125560e-01 + -1.59598637e+00 -1.38743235e+00 + 1.41362902e+00 1.06407789e+00 + 1.02584504e+00 -3.68219466e-01 + -1.04644488e+00 -1.48769392e-01 + 2.66990191e+00 8.57633492e-01 + -1.84251857e+00 -9.82430175e-01 + 9.71404204e-01 -2.81934209e-01 + -2.50177989e+00 -9.21260335e-01 + -1.31060074e+00 -5.84488113e-01 + -2.12129400e-01 -3.06244708e-02 + -5.28933882e+00 -2.50663129e+00 + 1.90220541e+00 1.08662918e+00 + -3.99366086e-02 -6.87178973e-01 + -4.93417342e-01 4.37354182e-01 + 2.13494486e+00 1.37679569e+00 + 2.18396765e+00 5.81023868e-01 + -3.07866587e+00 -1.45384974e+00 + 6.10894119e-01 -4.17050124e-01 + -1.88766952e+00 -8.86160058e-01 + 3.34527253e+00 1.78571260e+00 + 6.87769059e-01 -5.01157336e-01 + 2.60470837e+00 1.45853560e+00 + -6.49315691e-01 -9.16112805e-01 + -1.29817687e+00 -2.15924339e-01 + -1.20100409e-03 -4.03137422e-01 + -1.36471594e+00 -6.93266356e-01 + 1.38682062e+00 7.15131598e-01 + 2.47830103e+00 1.24862305e+00 + -2.78288147e+00 -1.03329235e+00 + -7.33443403e-01 -6.11041652e-01 + -4.12745671e-01 -5.96133390e-02 + -2.58632336e+00 -4.51557058e-01 + -1.16570367e+00 -1.27065510e+00 + 2.76187104e+00 2.21895451e-01 + -3.80443767e+00 -1.66319902e+00 + 9.84658633e-01 6.81475569e-01 + 9.33814584e-01 -4.89335563e-02 + -4.63427997e-01 1.72989539e-01 + 1.82401546e+00 3.60164021e-01 + -5.36521077e-01 -8.08691351e-01 + -1.37367030e+00 -1.02126160e+00 + -3.70310682e+00 -1.19840844e+00 + -1.51894242e+00 -3.89510223e-01 + -3.67347940e-01 -3.25540516e-02 + -1.00988595e+00 1.82802194e-01 + 2.01622795e+00 7.86367901e-01 + 1.02440231e+00 8.79780360e-01 + -3.05971480e+00 -8.40901527e-01 + 2.73909457e+00 1.20558628e+00 + 2.39559056e+00 1.10786694e+00 + 1.65471544e+00 7.33824651e-01 + 2.18546787e+00 6.41168955e-01 + 1.47152266e+00 3.91839132e-01 + 1.45811155e+00 5.21820495e-01 + -4.27531469e-02 -3.52343068e-03 + -9.54948010e-01 -1.52313876e-01 + 7.57151215e-01 -5.68728854e-03 + -8.46205751e-01 -7.54580229e-01 + 4.14493548e+00 1.45532780e+00 + 4.58688968e-01 -4.54012803e-02 + -1.49295381e+00 -4.57471758e-01 + 1.80020351e+00 8.13724973e-01 + -5.82727738e+00 -2.18269581e+00 + -2.09017809e+00 -1.18305177e+00 + -2.31628303e+00 -7.21600235e-01 + -8.09679091e-01 -1.49101752e-01 + 8.88005605e-01 8.57940857e-01 + -1.44148219e+00 -3.10926299e-01 + 3.68828186e-01 -3.08848059e-01 + -6.63267389e-01 -8.58950139e-02 + -1.14702569e+00 -6.32147854e-01 + -1.51741715e+00 -8.53330564e-01 + -1.33903718e+00 -1.45875547e-01 + 4.12485387e+00 1.85620435e+00 + -2.42353639e+00 -2.92669850e-01 + 1.88708583e+00 9.35984730e-01 + 2.15585179e+00 6.30469051e-01 + -1.13627973e-01 -1.62554045e-01 + 2.04540494e+00 1.36599834e+00 + 2.81591381e+00 1.60897941e+00 + 3.02736260e-02 3.83255815e-03 + 7.97634013e-02 -2.82035099e-01 + -3.24607473e-01 -5.30065956e-01 + -3.91862894e+00 -1.94083334e+00 + 1.56360901e+00 7.93882743e-01 + -1.03905772e+00 6.25590229e-01 + 2.54746492e+00 1.64233560e+00 + -4.80774423e-01 -8.92298032e-02 + 9.06979990e-02 1.05020427e+00 + -2.47521290e+00 -1.78275982e-01 + -3.91871729e-01 3.80285423e-01 + 1.00658382e+00 4.58947483e-01 + 4.68102941e-01 1.02992741e+00 + 4.44242568e-01 2.89870239e-01 + 3.29684452e+00 1.44677474e+00 + -2.24983007e+00 -9.65574499e-01 + -3.54453926e-01 -3.99020325e-01 + -3.87429665e+00 -1.90079739e+00 + 2.02656674e+00 1.12444894e+00 + 3.77011621e+00 1.43200852e+00 + 1.61259275e+00 4.65417399e-01 + 2.28725434e+00 6.79181395e-01 + 2.75421009e+00 2.27327345e+00 + -2.40894409e+00 -1.03926359e+00 + 1.52996651e-01 -2.73373046e-02 + -2.63218977e+00 -7.22802821e-01 + 2.77688169e+00 1.15310186e+00 + 1.18832341e+00 4.73457165e-01 + -2.35536326e+00 -1.08034554e+00 + -5.84221627e-01 1.03505984e-02 + 2.96730300e+00 1.33478306e+00 + -8.61947692e-01 6.09137051e-02 + 8.22343921e-01 -8.14155286e-02 + 1.75809015e+00 1.07921470e+00 + 1.19501279e+00 1.05309972e+00 + -1.75901792e+00 9.75320161e-02 + 1.64398635e+00 9.54384323e-01 + -2.21878052e-01 -3.64847144e-01 + -2.03128968e+00 -8.57866419e-01 + 1.86750633e+00 7.08524487e-01 + 8.03972976e-01 3.47404314e-01 + 3.41203749e+00 1.39810900e+00 + 4.22397681e-01 -6.41440488e-01 + -4.88493360e+00 -1.58967816e+00 + -1.67649284e-01 -1.08485915e-01 + 2.11489023e+00 1.50506158e+00 + -1.81639929e+00 -3.85542192e-01 + 2.24044819e-01 -1.45100577e-01 + -3.39262411e+00 -1.44394324e+00 + 1.68706599e+00 2.29199618e-01 + -1.94093257e+00 -1.65975814e-01 + 8.28143367e-01 5.92109281e-01 + -8.29587998e-01 -9.57130831e-01 + -1.50011401e+00 -8.36802092e-01 + 2.40770449e+00 9.32820177e-01 + 7.41391309e-02 3.12878473e-01 + 1.87745264e-01 6.19231425e-01 + 9.57622692e-01 -2.20640033e-01 + 3.18479243e+00 1.02986233e+00 + 2.43133846e+00 8.41302677e-01 + -7.09963834e-01 1.99718943e-01 + -2.88253498e-01 -3.62772094e-01 + 5.14052574e+00 1.79304595e+00 + -3.27930993e+00 -1.29177973e+00 + -1.16723536e+00 1.29519656e-01 + 1.04801056e+00 3.41508300e-01 + -3.99256195e+00 -2.51176471e+00 + -7.62824318e-01 -6.84242153e-01 + 2.71524986e-02 5.35157164e-02 + 3.26430102e+00 1.34887262e+00 + -1.72357766e+00 -4.94524388e-01 + -3.81149536e+00 -1.28121944e+00 + 3.36919354e+00 1.10672075e+00 + -3.14841757e+00 -7.10713767e-01 + -3.16463676e+00 -7.58558435e-01 + -2.44745969e+00 -1.08816514e+00 + 2.79173264e-01 -2.19652051e-02 + 4.15309883e-01 6.07502790e-01 + -9.51007417e-01 -5.83976336e-01 + -1.47929839e+00 -8.39850409e-01 + 2.38335703e+00 6.16055149e-01 + -7.47749031e-01 -5.56164928e-01 + -3.65643622e-01 -5.06684411e-01 + -1.76634163e+00 -7.86382097e-01 + 6.76372222e-01 -3.06592181e-01 + -1.33505058e+00 -1.18301441e-01 + 3.59660179e+00 2.00424178e+00 + -7.88912762e-02 8.71956146e-02 + 1.22656397e+00 1.18149583e+00 + 4.24919729e+00 1.20082355e+00 + 2.94607456e+00 1.00676505e+00 + 7.46061275e-02 4.41761753e-02 + -2.47738025e-02 1.92737701e-01 + -2.20509316e-01 -3.79163193e-01 + -3.50222190e-01 3.58727299e-01 + -3.64788014e+00 -1.36107312e+00 + 3.56062799e+00 9.27032742e-01 + 1.04317289e+00 6.08035970e-01 + 4.06718718e-01 3.00628051e-01 + 4.33158086e+00 2.25860714e+00 + 2.13917145e-01 -1.72757967e-01 + -1.40637998e+00 -1.14119465e+00 + 3.61554872e+00 1.87797348e+00 + 1.01726871e+00 5.70255097e-01 + -7.04902551e-01 2.16444147e-01 + -2.51492186e+00 -8.52997369e-01 + 1.85097530e+00 1.15124496e+00 + -8.67569714e-01 -3.05682432e-01 + 8.07550858e-01 5.88901608e-01 + 1.85186755e-01 -1.94589367e-01 + -1.23378238e+00 -7.84128347e-01 + -1.22713161e+00 -4.21218235e-01 + 2.97751165e-01 2.81055275e-01 + 4.77703554e+00 1.66265524e+00 + 2.51549669e+00 7.49980674e-01 + 2.76510822e-01 1.40456909e-01 + 1.98740905e+00 -1.79608212e-01 + 9.35429145e-01 8.44344180e-01 + -1.20854492e+00 -5.00598453e-01 + 2.29936219e+00 8.10236668e-01 + 6.92555544e-01 -2.65891331e-01 + -1.58050994e+00 2.31237821e-01 + -1.50864880e+00 -9.49661690e-01 + -1.27689206e+00 -7.18260016e-01 + -3.12517127e+00 -1.75587113e+00 + 8.16062912e-02 -6.56551804e-01 + -5.02479939e-01 -4.67162543e-01 + -5.47435788e+00 -2.47799576e+00 + 1.95872901e-02 5.80874076e-01 + -1.59064958e+00 -6.34554756e-01 + -3.77521478e+00 -1.74301790e+00 + 5.89628224e-01 8.55736553e-01 + -1.81903543e+00 -7.50011008e-01 + 1.38557775e+00 3.71490991e-01 + 9.70032652e-01 -7.11356016e-01 + 2.63539625e-01 -4.20994771e-01 + 2.12154222e+00 8.19081400e-01 + -6.56977937e-01 -1.37810098e-01 + 8.91309581e-01 2.77864361e-01 + -7.43693195e-01 -1.46293770e-01 + 2.24447769e+00 4.00911438e-01 + -2.25169262e-01 2.04148801e-02 + 1.68744684e+00 9.47573007e-01 + 2.73086373e-01 3.30877195e-01 + 5.54294414e+00 2.14198009e+00 + -8.49238733e-01 3.65603298e-02 + 2.39685712e+00 1.17951039e+00 + -2.58230528e+00 -5.52116673e-01 + 2.79785277e+00 2.88833717e-01 + -1.96576188e-01 1.11652123e+00 + -4.69383301e-01 1.96496282e-01 + -1.95011845e+00 -6.15235169e-01 + 1.03379890e-02 2.33701239e-01 + 4.18933607e-01 2.77939814e-01 + -1.18473337e+00 -4.10051126e-01 + -7.61499744e-01 -1.43658094e+00 + -1.65586092e+00 -3.41615303e-01 + -5.58523700e-02 -5.21837080e-01 + -2.40331088e+00 -2.64521583e-01 + 2.24925206e+00 6.79843335e-02 + 1.46360479e+00 1.04271443e+00 + -3.09255443e+00 -1.82548953e+00 + 2.11325841e+00 1.14996627e+00 + -8.70657797e-01 1.02461839e-01 + -5.71056521e-01 9.71232588e-02 + -3.37870752e+00 -1.54091877e+00 + 1.03907189e+00 -1.35661392e-01 + 8.40057486e-01 6.12172413e-02 + -1.30998234e+00 -1.34077226e+00 + 7.53744974e-01 1.49447350e-01 + 9.13995056e-01 -1.81227962e-01 + 2.28386229e-01 3.74498520e-01 + 2.54829151e-01 -2.88802704e-01 + 1.61709009e+00 2.09319193e-01 + -1.12579380e+00 -5.95955338e-01 + -2.69610726e+00 -2.76222736e-01 + -2.63773329e+00 -7.84491970e-01 + -2.62167427e+00 -1.54792874e+00 + -4.80639856e-01 -1.30582102e-01 + -1.26130891e+00 -8.86841840e-01 + -1.24951950e+00 -1.18182622e+00 + -1.40107574e+00 -9.13695575e-01 + 4.99872179e-01 4.69014702e-01 + -2.03550193e-02 -1.48859738e-01 + -1.50189069e+00 -2.97714278e-02 + -2.07846113e+00 -7.29937809e-01 + -5.50576792e-01 -7.03151525e-01 + -3.88069238e+00 -1.63215295e+00 + 2.97032988e+00 6.43571144e-01 + -1.85999273e-01 1.18107620e+00 + 1.79249709e+00 6.65356160e-01 + 2.68842472e+00 1.35703255e+00 + 1.07675417e+00 1.39845588e-01 + 8.01226349e-01 2.11392275e-01 + 9.64329379e-01 3.96146195e-01 + -8.22529511e-01 1.96080831e-01 + 1.92481841e+00 4.62985744e-01 + 3.69756927e-01 3.77135799e-01 + 1.19807835e+00 8.87715050e-01 + -1.01363587e+00 -2.48151636e-01 + 8.53071010e-01 4.96887868e-01 + -3.41120553e+00 -1.35401843e+00 + -2.64787381e+00 -1.08690563e+00 + -1.11416759e+00 -4.43848915e-01 + 1.46242648e+00 6.17106076e-02 + -7.52968881e-01 -9.20972209e-01 + -1.22492228e+00 -5.40327617e-01 + 1.08001827e+00 5.29593785e-01 + -2.58706464e-01 1.13022085e-01 + -4.27394011e-01 1.17864354e-02 + -3.20728413e+00 -1.71224737e-01 + 1.71398530e+00 8.68885893e-01 + 2.12067866e+00 1.45092772e+00 + 4.32782616e-01 -3.34117769e-01 + 7.80084374e-01 -1.35100217e-01 + -2.05547729e+00 -4.70217750e-01 + 2.38379736e+00 1.09186058e+00 + -2.80825477e+00 -1.03320187e+00 + 2.63434576e+00 1.15671733e+00 + -1.60936214e+00 1.91843035e-01 + -5.02298769e+00 -2.32820708e+00 + 1.90349195e+00 1.45215416e+00 + 3.00232888e-01 3.24412586e-01 + -2.46503943e+00 -1.19550010e+00 + 1.06304233e+00 2.20136246e-01 + -2.99101388e+00 -1.58299318e+00 + 2.30071719e+00 1.12881362e+00 + -2.37587247e+00 -8.08298336e-01 + 7.27006308e-01 3.80828984e-01 + 2.61199061e+00 1.56473491e+00 + 8.33936357e-01 -1.42189425e-01 + 3.13291605e+00 1.77771210e+00 + 2.21917371e+00 5.68427075e-01 + 2.38867649e+00 9.06637262e-01 + -6.92959466e+00 -3.57682881e+00 + 2.57904824e+00 5.93959108e-01 + 2.71452670e+00 1.34436199e+00 + 4.39988761e+00 2.13124672e+00 + 5.71783077e-01 5.08346173e-01 + -3.65399429e+00 -1.18192861e+00 + 4.46176453e-01 3.75685594e-02 + -2.97501495e+00 -1.69459236e+00 + 1.60855728e+00 9.20930014e-01 + -1.44270290e+00 -1.93922306e-01 + 1.67624229e+00 1.66233866e+00 + -1.42579598e+00 -1.44990145e-01 + 1.19923176e+00 4.58490278e-01 + -9.00068460e-01 5.09701825e-02 + -1.69391694e+00 -7.60070300e-01 + -1.36576440e+00 -5.24244256e-01 + -1.03016748e+00 -3.44625878e-01 + 2.40519313e+00 1.09947587e+00 + 1.50365433e+00 1.06464802e+00 + -1.07609727e+00 -3.68897187e-01 + 2.44969069e+00 1.28486192e+00 + -1.25610307e+00 -1.14644789e+00 + 2.05962899e+00 4.31162369e-01 + -7.15886908e-01 -6.11587804e-02 + -6.92354119e-01 -7.85019920e-01 + -1.63016508e+00 -5.96944975e-01 + 1.90352536e+00 1.28197457e+00 + -4.01535243e+00 -1.81934488e+00 + -1.07534435e+00 -2.10544784e-01 + 3.25500866e-01 7.69603661e-01 + 2.18443365e+00 6.59773335e-01 + 8.80856790e-01 6.39505913e-01 + -2.23956372e-01 -4.65940132e-01 + -1.06766519e+00 -5.38388505e-03 + 7.25556863e-01 -2.91123488e-01 + -4.69451411e-01 7.89182650e-02 + 2.58146587e+00 1.29653243e+00 + 1.53747468e-01 7.69239075e-01 + -4.61152262e-01 -4.04151413e-01 + 1.48183517e+00 8.10079506e-01 + -1.83402614e+00 -1.36939322e+00 + 1.49315501e+00 7.95225425e-01 + 1.41922346e+00 1.05582774e-01 + 1.57473493e-01 9.70795657e-01 + -2.67603254e+00 -7.48562280e-01 + -8.49156216e-01 -6.05762529e-03 + 1.12944274e+00 3.67741591e-01 + 1.94228071e-01 5.28188141e-01 + -3.65610158e-01 4.05851838e-01 + -1.98839111e+00 -1.38452764e+00 + 2.73765752e+00 8.24150530e-01 + 7.63728641e-01 3.51617707e-01 + 5.78307267e+00 1.68103612e+00 + 2.27547227e+00 3.60876164e-01 + -3.50681697e+00 -1.74429984e+00 + 4.01241184e+00 1.26227829e+00 + 2.44946343e+00 9.06119057e-01 + -2.96638941e+00 -9.01532322e-01 + 1.11267643e+00 -3.43333381e-01 + -6.61868994e-01 -3.44666391e-01 + -8.34917179e-01 5.69478372e-01 + -1.91888454e+00 -3.03791075e-01 + 1.50397636e+00 8.31961240e-01 + 6.12260198e+00 2.16851807e+00 + 1.34093127e+00 8.86649385e-01 + 1.48748519e+00 8.26273697e-01 + 7.62243068e-01 2.64841396e-01 + -2.17604986e+00 -3.54219958e-01 + 2.64708640e-01 -4.38136718e-02 + 1.44725372e+00 1.18499914e-01 + -6.71259446e-01 -1.19526851e-01 + 2.40134595e-01 -8.90042323e-02 + -3.57238199e+00 -1.23166201e+00 + -3.77626645e+00 -1.19533443e+00 + -3.81101035e-01 -4.94160532e-01 + -3.02758757e+00 -1.18436066e+00 + 2.59116298e-01 1.38023047e+00 + 4.17900116e+00 1.12065959e+00 + 1.54598848e+00 2.89806755e-01 + 1.00656475e+00 1.76974511e-01 + -4.15730234e-01 -6.22681694e-01 + -6.00903565e-01 -1.43256959e-01 + -6.03652508e-01 -5.09936379e-01 + -1.94096658e+00 -9.48789544e-01 + -1.74464105e+00 -8.50491590e-01 + 1.17652544e+00 1.88118317e+00 + 2.35507776e+00 1.44000205e+00 + 2.63067924e+00 1.06692988e+00 + 2.88805386e+00 1.23924715e+00 + 8.27595008e-01 5.75364692e-01 + 3.91384216e-01 9.72781920e-02 + -1.03866816e+00 -1.37567768e+00 + -1.34777969e+00 -8.40266025e-02 + -4.12904508e+00 -1.67618340e+00 + 1.27918111e+00 3.52085961e-01 + 4.15361174e-01 6.28896189e-01 + -7.00539496e-01 4.80447955e-02 + -1.62332639e+00 -5.98236485e-01 + 1.45957300e+00 1.00305154e+00 + -3.06875603e+00 -1.25897545e+00 + -1.94708176e+00 4.85143006e-01 + 3.55744156e+00 -1.07468822e+00 + 1.21602223e+00 1.28768827e-01 + 1.89093098e+00 -4.70835659e-01 + -6.55759125e+00 2.70114082e+00 + 8.96843535e-01 -3.98115252e-01 + 4.13450429e+00 -2.32069236e+00 + 2.37764218e+00 -1.09098890e+00 + -1.11388901e+00 6.27083097e-01 + -6.34116929e-01 4.62816387e-01 + 2.90203079e+00 -1.33589143e+00 + 3.17457598e+00 -5.13575945e-01 + -1.76362299e+00 5.71820693e-01 + 1.66103362e+00 -8.99466249e-01 + -2.53947433e+00 8.40084780e-01 + 4.36631397e-01 7.24234261e-02 + -1.87589394e+00 5.08529113e-01 + 4.49563965e+00 -9.43365992e-01 + 1.78876299e+00 -1.27076149e+00 + -1.16269107e-01 -4.55078316e-01 + 1.92966079e+00 -8.05371385e-01 + 2.20632583e+00 -9.00919345e-01 + 1.52387824e+00 -4.82391996e-01 + 8.04004564e-01 -2.73650595e-01 + -7.75326067e-01 1.07469566e+00 + 1.83226282e+00 -4.52173344e-01 + 1.25079758e-01 -3.52895417e-02 + -9.90957437e-01 8.55993130e-01 + 1.71623322e+00 -7.08691667e-01 + -2.86175924e+00 6.75160955e-01 + -8.40817853e-01 -1.00361809e-01 + 1.33393000e+00 -4.65788123e-01 + 5.29394114e-01 -5.44881619e-02 + -8.07435599e-01 8.27353370e-01 + -4.33165824e+00 1.97299638e+00 + 1.26452422e+00 -8.34070486e-01 + 1.45996394e-02 2.97736043e-01 + -1.64489287e+00 6.72839598e-01 + -5.74234578e+00 3.20975117e+00 + 2.13841341e-02 3.64514015e-01 + 6.68084924e+00 -2.27464254e+00 + -3.22881590e+00 8.01879324e-01 + 3.02534313e-01 -4.56222796e-01 + -5.84520734e+00 1.95678162e+00 + 2.81515232e+00 -1.72101318e+00 + -2.39620908e-01 2.69145522e-01 + -7.41669691e-01 -2.30283281e-01 + -2.15682714e+00 3.45313021e-01 + 1.23475788e+00 -7.32276553e-01 + -1.71816113e-01 1.20419560e-02 + 1.89174235e+00 2.27435901e-01 + -3.64511114e-01 1.72260361e-02 + -3.24143860e+00 6.50125817e-01 + -2.25707409e+00 5.66970751e-01 + 1.03901456e+00 -1.00588433e+00 + -5.09159710e+00 1.58736109e+00 + 1.45534075e+00 -5.83787452e-01 + 4.28879587e+00 -1.58006866e+00 + 8.52384427e-01 -1.11042299e+00 + 4.51431615e+00 -2.63844265e+00 + -4.33042648e+00 1.86497078e+00 + -2.13568046e+00 5.82559743e-01 + -4.42568887e+00 1.26131214e+00 + 3.15821315e+00 -1.61515905e+00 + -3.14125204e+00 8.49604386e-01 + 6.54152300e-01 -2.04624711e-01 + -3.73374317e-01 9.94187820e-02 + -3.96177282e+00 1.27245623e+00 + 9.59825199e-01 -1.15547861e+00 + 3.56902055e+00 -1.46591091e+00 + 1.55433633e-02 6.93544345e-01 + 1.15684646e+00 -4.99836352e-01 + 3.11824573e+00 -4.75900506e-01 + -8.61706369e-01 -3.50774059e-01 + 9.89057391e-01 -7.16878802e-01 + -4.94787870e+00 2.09137481e+00 + 1.37777347e+00 -1.34946349e+00 + -1.13161577e+00 8.05114754e-01 + 8.12020675e-01 -1.04849421e+00 + 4.73783881e+00 -2.26718812e+00 + 8.99579366e-01 -8.89764451e-02 + 4.78524868e+00 -2.25795843e+00 + 1.75164590e+00 -1.73822209e-01 + 1.30204590e+00 -7.26724717e-01 + -7.26526403e-01 -5.23925361e-02 + 2.01255351e+00 -1.69965366e+00 + 9.87852740e-01 -4.63577220e-01 + 2.45957762e+00 -1.29278962e+00 + -3.13817948e+00 1.64433038e+00 + -1.76302159e+00 9.62784302e-01 + -1.91106331e+00 5.81460008e-01 + -3.30883001e+00 1.30378978e+00 + 5.54376450e-01 3.78814272e-01 + 1.09982111e+00 -1.47969612e+00 + -2.61300705e-02 -1.42573464e-01 + -2.22096157e+00 7.75684440e-01 + 1.70319323e+00 -2.89738444e-01 + -1.43223842e+00 6.39284281e-01 + 2.34360959e-01 -1.64379268e-01 + -2.67147991e+00 9.46548086e-01 + 1.51131425e+00 -4.91594395e-01 + -2.48446856e+00 1.01286123e+00 + 1.50534658e-01 -2.94620246e-01 + -1.66966792e+00 1.67755508e+00 + -1.50094241e+00 3.30163095e-01 + 2.27681194e+00 -1.08064317e+00 + 2.05122965e+00 -1.15165939e+00 + -4.23509309e-01 -6.56906167e-02 + 1.80084023e+00 -1.07228556e+00 + -2.65769521e+00 1.18023206e+00 + 2.02852676e+00 -8.06793574e-02 + -4.49544185e+00 2.68200163e+00 + -7.50043216e-01 1.17079331e+00 + 6.80060893e-02 3.99055351e-01 + -3.83634635e+00 1.38406887e+00 + 3.24858545e-01 -9.25273218e-02 + -2.19895100e+00 1.47819500e+00 + -3.61569522e-01 -1.03188739e-01 + 1.12180375e-01 -9.52696354e-02 + -1.31477803e+00 1.79900570e-01 + 2.39573628e+00 -6.09739269e-01 + -1.00135700e+00 6.02837296e-01 + -4.11994589e+00 2.49599192e+00 + -1.54196236e-01 -4.84921951e-01 + 5.92569908e-01 -1.87310359e-01 + 3.85407741e+00 -1.50979925e+00 + 5.17802528e+00 -2.26032607e+00 + -1.37018916e+00 1.87111822e-01 + 8.46682996e-01 -3.56676331e-01 + -1.17559949e+00 5.29057734e-02 + -5.56475671e-02 6.79049243e-02 + 1.07851745e+00 -5.14535101e-01 + -2.71622446e+00 1.00151846e+00 + -1.08477208e+00 8.81391054e-01 + 5.50755824e-01 -5.20577727e-02 + 4.70885495e+00 -2.04220397e+00 + -1.87375336e-01 -6.16962830e-02 + 3.52097100e-01 2.21163550e-01 + 7.07929984e-01 -1.75827590e-01 + -1.22149219e+00 1.83084346e-01 + 2.58247412e+00 -6.15914898e-01 + -6.01206182e-01 -2.29832987e-01 + 9.83360449e-01 -3.75870060e-01 + -3.20027685e+00 1.35467480e+00 + 1.79178978e+00 -1.38531981e+00 + -3.30376867e-01 -1.16250192e-01 + -1.89053055e+00 5.68463567e-01 + -4.20604849e+00 1.65429681e+00 + -1.01185529e+00 1.92801240e-01 + -6.18819882e-01 5.42206996e-01 + -5.08091672e+00 2.61598591e+00 + -2.62570344e+00 2.51590658e+00 + 3.05577906e+00 -1.49090609e+00 + 2.77609677e+00 -1.37681378e+00 + -7.93515301e-02 4.28072744e-01 + -2.08359471e+00 8.94334295e-01 + 2.20163801e+00 4.01127167e-02 + -1.18145785e-01 -2.06822464e-01 + -2.74788298e-01 2.96250607e-01 + 1.59613555e+00 -3.87246203e-01 + -3.82971472e-01 -3.39716093e-02 + -4.20311307e-02 3.88529510e-01 + 1.52128574e+00 -9.33138876e-01 + -9.06584458e-01 -2.75016094e-02 + 3.56216834e+00 -9.99384622e-01 + 2.11964220e+00 -9.98749118e-02 + 4.01203480e+00 -2.03032745e+00 + -1.24171557e+00 1.97596725e-01 + -1.57230455e+00 4.14126609e-01 + -1.85484741e+00 5.40041563e-01 + 1.76329831e+00 -6.95967734e-01 + -2.29439232e-01 5.08669245e-01 + -5.45124276e+00 2.26907549e+00 + -5.71364288e-02 5.04476476e-01 + 3.12468018e+00 -1.46358879e+00 + 8.20017359e-01 6.51949028e-01 + -1.33977500e+00 2.83634232e-04 + -1.83311685e+00 1.23947117e+00 + 6.31205922e-01 1.19792164e-02 + -2.21967834e+00 6.94056232e-01 + -1.41693842e+00 9.93526233e-01 + -7.58885703e-01 6.78547347e-01 + 3.60239086e+00 -1.08644935e+00 + 6.72217073e-02 3.00036011e-02 + -3.42680958e-01 -3.48049352e-01 + 1.87546079e+00 -4.78018246e-01 + 7.00485821e-01 -3.52905383e-01 + -8.54580948e-01 8.17330861e-01 + 8.19123706e-01 -5.73927281e-01 + 2.70855639e-01 -3.08940052e-01 + -1.05059952e+00 3.27873168e-01 + 1.08282999e+00 4.84559349e-02 + -7.89899220e-01 1.22291138e+00 + -2.87939816e+00 7.17403497e-01 + -2.08429452e+00 8.87409226e-01 + 1.58409232e+00 -4.74123532e-01 + 1.26882735e+00 1.59162510e-01 + -2.53782993e+00 6.18253491e-01 + -8.92757445e-01 3.35979011e-01 + 1.31867900e+00 -1.17355054e+00 + 1.14918879e-01 -5.35184038e-01 + -1.70288738e-01 5.35868087e-02 + 4.21355121e-01 5.41848690e-02 + 2.07926943e+00 -5.72538144e-01 + 4.08788970e-01 3.77655777e-01 + -3.39631381e+00 9.84216764e-01 + 2.94170163e+00 -1.83120916e+00 + -7.94798752e-01 7.39889052e-01 + 1.46555463e+00 -4.62275563e-01 + 2.57255955e+00 -1.04671434e+00 + 8.45042540e-01 -1.96952892e-01 + -3.23526646e+00 1.60049846e+00 + 3.21948565e+00 -8.88376674e-01 + 1.43005104e+00 -9.21561086e-01 + 8.82360506e-01 2.98403872e-01 + -8.91168097e-01 1.01319072e+00 + -5.13215241e-01 -2.47182649e-01 + -1.35759444e+00 7.07450608e-02 + -4.04550983e+00 2.23534867e+00 + 1.39348883e+00 3.81637747e-01 + -2.85676418e+00 1.53240862e+00 + -1.37183120e+00 6.37977425e-02 + -3.88195859e+00 1.73887145e+00 + 1.19509776e+00 -6.25013512e-01 + -2.80062734e+00 1.79840585e+00 + 1.96558429e+00 -4.70997234e-01 + 1.93111352e+00 -9.70318441e-01 + 3.57991190e+00 -1.65065116e+00 + 2.12831714e+00 -1.11531708e+00 + -3.95661018e-01 -8.54339904e-02 + -2.41630441e+00 1.65166304e+00 + 7.55412624e-01 -1.53453579e-01 + -1.77043450e+00 1.39928715e+00 + -9.32631260e-01 8.73649199e-01 + 1.53342205e+00 -8.39569765e-01 + -6.29846924e-02 1.25023084e-01 + 3.31509049e+00 -1.10733235e+00 + -2.18957109e+00 3.07376993e-01 + -2.35740747e+00 6.47437564e-01 + -2.22142438e+00 8.47318938e-01 + -6.51401147e-01 3.48398562e-01 + 2.75763095e+00 -1.21390708e+00 + 1.12550484e+00 -5.61412847e-01 + -5.65053161e-01 6.74365205e-02 + 1.68952456e+00 -6.57566096e-01 + 8.95598401e-01 3.96738993e-01 + -1.86537066e+00 9.44129208e-01 + -2.59933294e+00 2.57423247e-01 + -6.59598267e-01 1.91828851e-02 + -2.64506676e+00 8.41783205e-01 + -1.25911802e+00 5.52425066e-01 + -1.39754507e+00 3.73689222e-01 + 5.49550729e-02 1.35071215e+00 + 3.31874811e+00 -1.05682424e+00 + 3.63159604e+00 -1.42864695e+00 + -4.45944617e+00 1.42889446e+00 + 5.87314342e-01 -4.88892988e-01 + -7.26130820e-01 1.51936106e-01 + -1.79246441e+00 6.05888105e-01 + -5.50948207e-01 6.21443081e-01 + -3.17246063e-01 1.77213880e-01 + -2.00098937e+00 1.23799074e+00 + 4.33790961e+00 -1.08490465e+00 + -2.03114114e+00 1.31613237e+00 + -6.29216542e+00 1.92406317e+00 + -1.60265624e+00 8.87947500e-01 + 8.64465062e-01 -8.37416270e-01 + -2.14273937e+00 8.05485900e-01 + -2.36844256e+00 6.17915124e-01 + -1.40429636e+00 6.78296866e-01 + 9.99019988e-01 -5.84297572e-01 + 7.38824546e-01 1.68838678e-01 + 1.45681238e+00 3.04641461e-01 + 2.15914949e+00 -3.43089227e-01 + -1.23895930e+00 1.05339864e-01 + -1.23162264e+00 6.46629863e-01 + 2.28183862e+00 -9.24157063e-01 + -4.29615882e-01 5.69130863e-01 + -1.37449121e+00 -9.12032183e-01 + -7.33890904e-01 -3.91865471e-02 + 8.41400661e-01 -4.76002200e-01 + -1.73349274e-01 -6.84143467e-02 + 3.16042891e+00 -1.32651856e+00 + -3.78244609e+00 2.38619718e+00 + -3.69634380e+00 2.22368561e+00 + 1.83766344e+00 -1.65675953e+00 + -1.63206002e+00 1.19484469e+00 + 3.68480064e-01 -5.70764494e-01 + 3.61982479e-01 1.04274409e-01 + 2.48863048e+00 -1.13285542e+00 + -2.81896488e+00 9.47958768e-01 + 5.74952901e-01 -2.75959392e-01 + 3.72783275e-01 -3.48937848e-01 + 1.95935716e+00 -1.06750415e+00 + 5.19357531e+00 -2.32070803e+00 + 4.09246149e+00 -1.89976700e+00 + -3.36666087e-01 8.17645057e-02 + 1.85453493e-01 3.76913151e-01 + -3.06458262e+00 1.34106402e+00 + -3.13796566e+00 7.00485099e-01 + 1.42964058e+00 -1.35536932e-01 + -1.23440423e-01 4.60094177e-02 + -2.86753037e+00 -5.21724160e-02 + 2.67113726e+00 -1.83746924e+00 + -1.35335062e+00 1.28238073e+00 + -2.43569899e+00 1.25998539e+00 + 1.26036740e-01 -2.35416844e-01 + -1.35725745e+00 7.37788491e-01 + -3.80897538e-01 3.30757889e-01 + 6.58694434e-01 -1.07566603e+00 + 2.11273640e+00 -9.02260632e-01 + 4.00755057e-01 -2.49229150e-02 + -1.80095812e+00 9.73099742e-01 + -2.68408372e+00 1.63737364e+00 + -2.66079826e+00 7.47289412e-01 + -9.92321439e-02 -1.49331396e-01 + 4.45678251e+00 -1.80352394e+00 + 1.35962915e+00 -1.31554389e+00 + -7.76601417e-01 -9.66173523e-02 + 1.68096348e+00 -6.27235133e-01 + 1.53081227e-01 -3.54216830e-01 + -1.54913095e+00 3.43689269e-01 + 5.29187357e-02 -6.73916964e-01 + -2.06606084e+00 8.34784242e-01 + 1.73701179e+00 -6.06467340e-01 + 1.55856757e+00 -2.58642780e-01 + 1.04349101e+00 -4.43027348e-01 + -1.02397719e+00 1.01308824e+00 + -2.13860204e-01 -4.73347361e-01 + -2.59004955e+00 1.43367853e+00 + 7.98457679e-01 2.18621627e-02 + -1.32974762e+00 4.61802208e-01 + 3.21419359e-01 2.30723316e-02 + 2.87201888e-02 6.24566672e-02 + -1.22261418e+00 6.02340363e-01 + 1.28750335e+00 -3.34839548e-02 + -9.67952623e-01 4.34470505e-01 + 2.02850324e+00 -9.05160255e-01 + -4.13946010e+00 2.33779091e+00 + -4.47508806e-01 3.06440495e-01 + -3.91543394e+00 1.68251022e+00 + -6.45193001e-01 5.29781162e-01 + -2.15518916e-02 5.07278355e-01 + -2.83356868e+00 1.00670227e+00 + 1.82989749e+00 -1.37329222e+00 + -1.09330213e+00 1.08560688e+00 + 1.90533722e+00 -1.28905879e+00 + 2.33986084e+00 2.30642626e-02 + 8.01940220e-01 -1.63986962e+00 + -4.23415165e+00 2.07530423e+00 + 9.33382522e-01 -7.62917211e-01 + -1.84033954e+00 1.07469401e+00 + -2.81938669e+00 1.07342024e+00 + -7.05169988e-01 2.13124943e-01 + 5.09598137e-01 1.32725493e-01 + -2.34558226e+00 8.62383168e-01 + -1.70322072e+00 2.70893796e-01 + 1.23652660e+00 -7.53216034e-02 + 2.84660646e+00 -3.48178304e-02 + 2.50250128e+00 -1.27770855e+00 + -1.00279469e+00 8.77194218e-01 + -4.34674121e-02 -2.12091350e-01 + -5.84151289e-01 1.50382340e-01 + -1.79024013e+00 4.24972808e-01 + -1.23434666e+00 -8.85546570e-02 + 1.36575412e+00 -6.42639880e-01 + -1.98429947e+00 2.27650336e-01 + 2.36253589e+00 -1.51340773e+00 + 8.79157643e-01 6.84142159e-01 + -2.18577755e+00 2.76526200e-01 + -3.55473434e-01 8.29976561e-01 + 1.16442595e+00 -5.97699411e-01 + -7.35528097e-01 2.40318183e-01 + -1.73702631e-01 7.33788663e-02 + -1.40451745e+00 3.24899628e-01 + -2.05434385e+00 5.68123738e-01 + 8.47876642e-01 -5.74224294e-01 + -6.91955602e-01 1.26009087e+00 + 2.56574498e+00 -1.15602581e+00 + 3.93306545e+00 -1.38398209e+00 + -2.73230251e+00 4.89062581e-01 + -1.04315474e+00 6.06335547e-01 + 1.23231431e+00 -4.46675065e-01 + -3.93035285e+00 1.43287651e+00 + -1.02132111e+00 9.58919791e-01 + -1.49425352e+00 1.06456165e+00 + -6.26485337e-01 1.03791402e+00 + -6.61772998e-01 2.63275425e-01 + -1.80940386e+00 5.70767403e-01 + 9.83720450e-01 -1.39449756e-01 + -2.24619662e+00 9.01044870e-01 + 8.94343014e-01 5.31038678e-02 + 1.95518199e-01 -2.81343295e-01 + -2.30533019e-01 -1.74478106e-01 + -2.01550361e+00 5.55958010e-01 + -4.36281469e+00 1.94374226e+00 + -5.18530457e+00 2.89278357e+00 + 2.67289101e+00 -2.98511449e-01 + -1.53566179e+00 -1.00588944e-01 + -6.09943217e-02 -1.56986047e-01 + -5.22146452e+00 1.66209208e+00 + -3.69777478e+00 2.26154873e+00 + 2.24607181e-01 -4.86934960e-01 + 2.49909450e+00 -1.03033370e+00 + -1.07841120e+00 8.22388054e-01 + -3.20697089e+00 1.09536143e+00 + 3.43524232e+00 -1.47289362e+00 + -5.65784134e-01 4.60365175e-01 + -1.76714734e+00 1.57752346e-01 + -7.77620365e-01 5.60153443e-01 + 6.34399352e-01 -5.22339836e-01 + 2.91011875e+00 -9.72623380e-01 + -1.19286824e+00 6.32370253e-01 + -2.18327609e-01 8.23953181e-01 + 3.42430842e-01 1.37098055e-01 + 1.28658034e+00 -9.11357320e-01 + 2.06914465e+00 -6.67556382e-01 + -6.69451020e-01 -6.38605102e-01 + -2.09312398e+00 1.16743634e+00 + -3.63778357e+00 1.91919157e+00 + 8.74685911e-01 -1.09931208e+00 + -3.91496791e+00 1.00808357e+00 + 1.29621330e+00 -8.32239802e-01 + 9.00222045e-01 -1.31159793e+00 + -1.12242062e+00 1.98517079e-01 + -3.71932852e-01 1.31667093e-01 + -2.23829610e+00 1.26328346e+00 + -2.08365062e+00 9.93385336e-01 + -1.91082720e+00 7.45866855e-01 + 4.38024917e+00 -2.05901118e+00 + -2.28872886e+00 6.85279335e-01 + 1.01274497e-01 -3.26227153e-01 + -5.04447572e-01 -3.18619513e-01 + 1.28537006e+00 -1.04573551e+00 + -7.83175212e-01 1.54791645e-01 + -3.89239175e+00 1.60017929e+00 + -8.87877111e-01 -1.04968005e-01 + 9.32215179e-01 -5.58691113e-01 + -6.44977127e-01 -2.23018375e-01 + 1.10141900e+00 -1.00666432e+00 + 2.92755687e-01 -1.45480350e-01 + 7.73580681e-01 -2.21150567e-01 + -1.40873709e+00 7.61548044e-01 + -8.89031805e-01 -3.48542923e-01 + 4.16844267e-01 -2.39914494e-01 + -4.64265832e-01 7.29581138e-01 + 1.99835179e+00 -7.70542813e-01 + 4.20523191e-02 -2.18783563e-01 + -6.32611758e-01 -3.09926115e-01 + 6.82912198e-02 -8.48327050e-01 + 1.92425229e+00 -1.37876951e+00 + 3.49461782e+00 -1.88354255e+00 + -3.25209026e+00 1.49809395e+00 + 6.59273182e-01 -2.37435654e-01 + -1.15517300e+00 8.46134387e-01 + 1.26756151e+00 -4.58988026e-01 + -3.99178418e+00 2.04153008e+00 + 7.05687841e-01 -6.83433306e-01 + -1.61997342e+00 8.16577004e-01 + -3.89750399e-01 4.29753250e-01 + -2.53026432e-01 4.92861432e-01 + -3.16788324e+00 4.44285524e-01 + -7.86248901e-01 1.12753716e+00 + -3.02351433e+00 1.28419015e+00 + -1.30131355e+00 1.71226678e+00 + -4.08843475e+00 1.62063214e+00 + -3.09209403e+00 1.19958520e+00 + 1.49102271e+00 -1.11834864e+00 + -3.18059348e+00 5.74587042e-01 + 2.06054867e+00 3.25797860e-03 + -3.50999200e+00 2.02412428e+00 + -8.26610023e-01 3.46528211e-01 + 2.00546034e+00 -4.07333110e-01 + -9.69941653e-01 4.80953753e-01 + 4.47925660e+00 -2.33127314e+00 + 2.03845790e+00 -9.90439915e-01 + -1.11349191e+00 4.31183918e-01 + -4.03628396e+00 1.68509679e+00 + -1.48177601e+00 7.74322088e-01 + 3.07369385e+00 -9.57465886e-01 + 2.39011286e+00 -6.44506921e-01 + 2.91561991e+00 -8.78627328e-01 + 1.10212733e+00 -4.21637388e-01 + 5.31985231e-01 -6.17445696e-01 + -6.82340929e-01 -2.93529716e-01 + 1.94290679e+00 -4.64268634e-01 + 1.92262116e+00 -7.93142835e-01 + 4.73762800e+00 -1.63654174e+00 + -3.17848641e+00 8.05791391e-01 + 4.08739432e+00 -1.80816807e+00 + -7.60648826e-01 1.24216138e-01 + -2.24716400e+00 7.90020937e-01 + 1.64284052e+00 -7.18784070e-01 + 1.04410012e-01 -7.11195880e-02 + 2.18268225e+00 -7.01767831e-01 + 2.06218013e+00 -8.70251746e-01 + -1.35266581e+00 7.08456358e-01 + -1.38157779e+00 5.14401086e-01 + -3.28326008e+00 1.20988399e+00 + 8.85358917e-01 -8.12213495e-01 + -2.34067500e+00 3.67657353e-01 + 3.96878127e+00 -1.66841450e+00 + 1.36518053e+00 -8.33436812e-01 + 5.25771988e-01 -5.06121987e-01 + -2.25948361e+00 1.30663765e+00 + -2.57662070e+00 6.32114628e-01 + -3.43134685e+00 2.38106008e+00 + 2.31571924e+00 -1.56566818e+00 + -2.95397202e+00 1.05661888e+00 + -1.35331242e+00 6.76383411e-01 + 1.40977132e+00 -1.17775938e+00 + 1.52561996e+00 -9.83147176e-01 + 2.26550832e+00 -2.10464123e-02 + 6.23371684e-01 -5.30768122e-01 + -4.42356624e-01 9.72226986e-01 + 2.31517901e+00 -1.08468105e+00 + 1.97236640e+00 -1.42016619e+00 + 3.18618687e+00 -1.45056343e+00 + -2.75880360e+00 5.40254980e-01 + -1.92916581e+00 1.45029864e-01 + 1.90022524e+00 -6.03805754e-01 + -1.05446211e+00 5.74361752e-01 + 1.45990390e+00 -9.28233993e-01 + 5.14960557e+00 -2.07564096e+00 + -7.53104842e-01 1.55876958e-01 + 8.09490983e-02 -8.58886384e-02 + -1.56894969e+00 4.53497227e-01 + 1.36944658e-01 5.60670875e-01 + -5.32635329e-01 4.40309945e-01 + 1.32507853e+00 -5.83670099e-01 + 1.20676031e+00 -8.02296831e-01 + -3.65023422e+00 1.17211368e+00 + 1.53393850e+00 -6.17771312e-01 + -3.99977129e+00 1.71415137e+00 + 5.70705058e-01 -4.60771539e-01 + -2.20608002e+00 1.07866596e+00 + -1.09040244e+00 6.77441076e-01 + -5.09886482e-01 -1.97282128e-01 + -1.58062785e+00 6.18333697e-01 + -1.53295020e+00 4.02168701e-01 + -5.18580598e-01 2.25767177e-01 + 1.59514316e+00 -2.54983617e-01 + -5.91938655e+00 2.68223782e+00 + 2.84200509e+00 -1.04685313e+00 + 1.31298664e+00 -1.16672614e+00 + -2.36660033e+00 1.81359460e+00 + 6.94163290e-02 3.76658816e-01 + 2.33973934e+00 -8.33173023e-01 + -8.24640389e-01 7.83717285e-01 + -1.02888281e+00 1.04680766e+00 + 1.34750745e+00 -5.89568160e-01 + -2.48761231e+00 7.44199284e-01 + -1.04501559e+00 4.72326911e-01 + -3.14610089e+00 1.89843692e+00 + 2.13003416e-01 5.76633620e-01 + -1.69239608e+00 5.66070021e-01 + 1.80491280e+00 -9.31701080e-01 + -6.94362572e-02 6.96026587e-01 + 1.36502578e+00 -6.85599000e-02 + -7.76764337e-01 3.64328661e-01 + -2.67322167e+00 6.80150021e-01 + 1.84338485e+00 -1.18487494e+00 + 2.88009231e+00 -1.25700411e+00 + 1.17114433e+00 -7.69727080e-01 + 2.11576167e+00 2.81502116e-01 + -1.51470088e+00 2.61553540e-01 + 1.18923669e-01 -1.17890202e-01 + 4.48359786e+00 -1.81427466e+00 + -1.27055948e+00 9.92388998e-01 + -8.00276606e-01 9.11326621e-02 + 7.51764024e-01 -1.03676498e-01 + 1.35769348e-01 -2.11470084e-01 + 2.50731332e+00 -1.12418270e+00 + -2.49752781e-01 7.81224033e-02 + -6.23037902e-01 3.16599691e-01 + -3.93772902e+00 1.37195391e+00 + 1.74256361e+00 -1.12363582e+00 + -1.49737281e+00 5.98828310e-01 + 7.75592115e-01 -4.64733802e-01 + -2.26027693e+00 1.36991118e+00 + -1.62849836e+00 7.36899107e-01 + 2.36850751e+00 -9.32126872e-01 + 5.86169745e+00 -2.49342512e+00 + -5.37092226e-01 1.23821274e+00 + 2.80535867e+00 -1.93363302e+00 + -1.77638106e+00 9.10050276e-01 + 3.02692018e+00 -1.60774676e+00 + 1.97833084e+00 -1.50636531e+00 + 9.09168906e-01 -8.83799359e-01 + 2.39769655e+00 -7.56977869e-01 + 1.47283981e+00 -1.06749890e+00 + 2.92060943e-01 -6.07040605e-01 + -2.09278201e+00 7.71858590e-01 + 7.10015905e-01 -5.42768432e-01 + -2.16826169e-01 1.56897896e-01 + 4.56288247e+00 -2.08912680e+00 + -6.63374020e-01 6.67325183e-01 + 1.80564442e+00 -9.76366134e-01 + 3.28720168e+00 -4.66575145e-01 + -1.60463695e-01 -2.58428153e-01 + 1.78590750e+00 -3.96427146e-01 + 2.75950306e+00 -1.82102856e+00 + -1.18234310e+00 6.28073320e-01 + 4.11415835e+00 -2.33551216e+00 + 1.38721004e+00 -2.77450622e-01 + -2.94903545e+00 1.74813352e+00 + 8.67290400e-01 -6.51667894e-01 + 2.70022274e+00 -8.11832480e-01 + -2.06766146e+00 8.24047249e-01 + 3.90717142e+00 -1.20155758e+00 + -2.95102809e+00 1.36667968e+00 + 6.08815147e+00 -2.60737974e+00 + 2.78576476e+00 -7.86628755e-01 + -3.26258407e+00 1.09302450e+00 + 1.59849422e+00 -1.09705202e+00 + -2.50600710e-01 1.63243175e-01 + -4.90477087e-01 -4.57729572e-01 + -1.24837181e+00 3.22157840e-01 + -2.46341049e+00 1.06517849e+00 + 9.62880751e-01 4.56962496e-01 + 3.99964487e-01 2.07472802e-01 + 6.36657705e-01 -3.46400942e-02 + 4.91231407e-02 -1.40289235e-02 + -4.66683524e-02 -3.72326100e-01 + -5.22049702e-01 -1.70440260e-01 + 5.27062938e-01 -2.32628395e-01 + -2.69440318e+00 1.18914874e+00 + 3.65087539e+00 -1.53427267e+00 + -1.16546364e-01 4.93245392e-02 + 7.55931384e-01 -3.02980139e-01 + 2.06338745e+00 -6.24841225e-01 + 1.31177908e-01 7.29338183e-01 + 1.48021784e+00 -6.39509896e-01 + -5.98656707e-01 2.84525503e-01 + -2.18611080e+00 1.79549812e+00 + -2.91673624e+00 2.15772237e-01 + -8.95591350e-01 7.68250538e-01 + 1.36139762e+00 -1.93845144e-01 + 5.45730414e+00 -2.28114404e+00 + 3.22747247e-01 9.33582332e-01 + -1.46384504e+00 1.12801186e-01 + 4.26728166e-01 -2.33481242e-01 + -1.41327270e+00 8.16103740e-01 + -2.53998067e-01 1.44906646e-01 + -1.32436467e+00 1.87556361e-01 + -3.77313086e+00 1.32896038e+00 + 3.77651731e+00 -1.76548043e+00 + -2.45297093e+00 1.32571926e+00 + -6.55900588e-01 3.56921462e-01 + 9.25558722e-01 -4.51988954e-01 + 1.20732231e+00 -3.02821614e-01 + 3.72660154e-01 -1.89365208e-01 + -1.77090939e+00 9.18087975e-01 + 3.01127567e-01 2.67965829e-01 + -1.76708900e+00 4.62069259e-01 + -2.71812099e+00 1.57233508e+00 + -5.35297633e-01 4.99231535e-01 + 1.50507631e+00 -9.85763646e-01 + 3.00424787e+00 -1.29837562e+00 + -4.99311105e-01 3.91086482e-01 + 1.30125207e+00 -1.26247924e-01 + 4.01699483e-01 -4.46909391e-01 + -1.33635257e+00 5.12068703e-01 + 1.39229757e+00 -9.10974858e-01 + -1.74229508e+00 1.49475978e+00 + -1.21489414e+00 4.04193753e-01 + -3.36537605e-01 -6.74335427e-01 + -2.79186828e-01 8.48314720e-01 + -2.03080140e+00 1.66599815e+00 + -3.53064281e-01 -7.68582906e-04 + -5.30305657e+00 2.91091546e+00 + -1.20049972e+00 8.26578358e-01 + 2.95906989e-01 2.40215920e-01 + -1.42955534e+00 4.63480310e-01 + -1.87856619e+00 8.21459385e-01 + -2.71124720e+00 1.80246843e+00 + -3.06933780e+00 1.22235760e+00 + 5.21935582e-01 -1.27298218e+00 + -1.34175797e+00 7.69018937e-01 + -1.81962785e+00 1.15528991e+00 + -3.99227550e-01 2.93821598e-01 + 1.22533179e+00 -4.73846323e-01 + -2.08068359e-01 -1.75039817e-01 + -2.03068526e+00 1.50370503e+00 + -3.27606113e+00 1.74906330e+00 + -4.37802587e-01 -2.26956048e-01 + -7.69774213e-02 -3.54922468e-01 + 6.47160749e-02 -2.07334721e-01 + -1.37791524e+00 4.43766709e-01 + 3.29846803e+00 -1.04060799e+00 + -3.63704046e+00 1.05800226e+00 + -1.26716116e+00 1.13077353e+00 + 1.98549075e+00 -1.31864807e+00 + 1.85159500e+00 -5.78629560e-01 + -1.55295206e+00 1.23655857e+00 + 6.76026255e-01 9.18824125e-02 + 1.23418960e+00 -4.68162027e-01 + 2.43186642e+00 -9.22422440e-01 + -3.18729701e+00 1.77582673e+00 + -4.02945613e+00 1.14303496e+00 + -1.92694576e-01 1.03301431e-01 + 1.89554730e+00 -4.60128096e-01 + -2.55626581e+00 1.16057084e+00 + 6.89144365e-01 -9.94982900e-01 + -4.44680606e+00 2.19751983e+00 + -3.15196193e+00 1.18762993e+00 + -1.17434977e+00 1.04534656e+00 + 8.58386984e-02 -1.03947487e+00 + 3.33354973e-01 5.54813610e-01 + -9.37631808e-01 3.33450150e-01 + -2.50232471e+00 5.39720635e-01 + 1.03611949e+00 -7.16304095e-02 + -2.05556816e-02 -3.28992265e-01 + -2.24176201e+00 1.13077506e+00 + 4.53583688e+00 -1.10710212e+00 + 4.77389762e-01 -8.99445512e-01 + -2.69075551e+00 6.83176866e-01 + -2.21779724e+00 1.16916849e+00 + -1.09669056e+00 2.10044765e-01 + -8.45367920e-01 -8.45951423e-02 + 4.37558941e-01 -6.95904256e-01 + 1.84884195e+00 -1.71205136e-01 + -8.36371957e-01 5.62862478e-01 + 1.27786531e+00 -1.33362147e+00 + 2.90684492e+00 -7.49892184e-01 + -3.38652716e+00 1.51180670e+00 + -1.30945978e+00 7.09261928e-01 + -7.50471924e-01 -5.24637889e-01 + 1.18580718e+00 -9.97943971e-04 + -7.55395645e+00 3.19273590e+00 + 1.72822535e+00 -1.20996962e+00 + 5.67374320e-01 6.19573416e-01 + -2.99163781e+00 1.79721534e+00 + 1.49862187e+00 -6.05631846e-02 + 1.79503506e+00 -4.90419706e-01 + 3.85626054e+00 -1.95396324e+00 + -9.39188410e-01 7.96498057e-01 + 2.91986664e+00 -1.29392724e+00 + -1.54265750e+00 6.40727933e-01 + 1.14919794e+00 1.20834257e-01 + 2.00936817e+00 -1.53728359e+00 + 3.72468420e+00 -1.38704612e+00 + -1.27794802e+00 3.48543179e-01 + 3.63294077e-01 5.70623314e-01 + 1.49381016e+00 -6.04500534e-01 + 2.98912256e+00 -1.72295726e+00 + -1.80833817e+00 2.94907625e-01 + -3.19669622e+00 1.31888700e+00 + 1.45889401e+00 -8.88448639e-01 + -2.80045388e+00 1.01207060e+00 + -4.78379567e+00 1.48646520e+00 + 2.25510003e+00 -7.13372461e-01 + -9.74441433e-02 -2.17766373e-01 + 2.64468496e-01 -3.60842698e-01 + -5.98821713e+00 3.20197892e+00 + 2.67030213e-01 -5.36386416e-01 + 2.24546960e+00 -8.13464649e-01 + -4.89171414e-01 3.86255031e-01 + -7.45713706e-01 6.29800380e-01 + -3.30460503e-01 3.85127284e-01 + -4.19588147e+00 1.52793198e+00 + 5.42078582e-01 -2.61642741e-02 + 4.24938513e-01 -5.72936751e-01 + 2.82717288e+00 -6.75355024e-01 + -1.44741788e+00 5.03578028e-01 + -1.65547573e+00 7.76444277e-01 + 2.20361170e+00 -1.40835680e+00 + -3.69540235e+00 2.32953767e+00 + -1.41909357e-01 2.28989778e-01 + 1.92838879e+00 -8.72525737e-01 + 1.40708100e+00 -6.81849638e-02 + 1.24988112e+00 -1.39470590e-01 + -2.39435855e+00 7.26587655e-01 + 7.03985028e-01 4.85403277e-02 + 4.05214529e+00 -9.16928318e-01 + 3.74198837e-01 -5.04192358e-01 + -8.43374127e-01 2.36064018e-01 + -3.32253349e-01 7.47840055e-01 + -6.03725210e+00 1.95173337e+00 + 4.60829865e+00 -1.51191309e+00 + -1.46247098e+00 1.11140916e+00 + -9.60111157e-01 -1.23189114e-01 + -7.49613187e-01 4.53614129e-01 + -5.77838219e-01 2.07366469e-02 + 8.07652950e-01 -5.16272662e-01 + -6.02556049e-01 5.05318649e-01 + -1.28712445e-01 2.57836512e-01 + -5.27662820e+00 2.11790737e+00 + 5.40819308e+00 -2.15366022e+00 + 9.37742513e-02 -1.60221751e-01 + 4.55902865e+00 -1.24646307e+00 + -9.06582589e-01 1.92928110e-01 + 2.99928996e+00 -8.04301218e-01 + -3.24317381e+00 1.80076061e+00 + 3.20421743e-01 8.76524679e-01 + -5.29606705e-01 -3.16717696e-01 + -1.77264560e+00 7.52686776e-01 + -1.51706824e+00 8.43755103e-01 + 1.52759111e+00 -7.86814243e-01 + 4.74845617e-01 4.21319700e-01 + 6.97829149e-01 -8.15664881e-01 + 3.09564973e+00 -1.06202469e+00 + 2.95320379e+00 -1.98963943e+00 + -4.23033224e+00 1.41013338e+00 + 1.48576206e+00 8.02908511e-02 + 4.52041627e+00 -2.04620399e+00 + 6.58403922e-01 -7.60781799e-01 + 2.10667543e-01 1.15241731e-01 + 1.77702583e+00 -8.10271859e-01 + 2.41277385e+00 -1.46972042e+00 + 1.50685525e+00 -1.99272545e-01 + 7.61665522e-01 -4.11276152e-01 + 1.18352312e+00 -9.59908608e-01 + -3.32031305e-01 8.07500132e-02 + 1.16813118e+00 -1.73095194e-01 + 1.18363346e+00 -5.41565052e-01 + 5.17702179e-01 -7.62442035e-01 + 4.57401006e-01 -1.45951115e-02 + 1.49377115e-01 2.99571605e-01 + 1.40399453e+00 -1.30160353e+00 + 5.26231567e-01 3.52783752e-01 + -1.91136514e+00 4.24228635e-01 + 1.74156701e+00 -9.92076776e-01 + -4.89323391e+00 2.32483507e+00 + 2.54011209e+00 -8.80366295e-01 + -5.56925706e-01 1.48842026e-01 + -2.35904668e+00 9.60474853e-01 + 1.42216971e+00 -4.67062761e-01 + -1.10809680e+00 7.68684300e-01 + 4.09674726e+00 -1.90795680e+00 + -2.23048923e+00 9.03812542e-01 + 6.57025763e-01 1.36514871e-01 + 2.10944145e+00 -9.78897838e-02 + 1.22552525e+00 -2.50303867e-01 + 2.84620103e-01 -5.30164020e-01 + -2.13562585e+00 1.03503056e+00 + 1.32414902e-01 -8.14190240e-03 + -5.82433561e-01 3.21020292e-01 + -5.06473247e-01 3.11530419e-01 + 1.57162465e+00 -1.20763919e+00 + -1.43155284e+00 -2.51203698e-02 + -1.47093713e+00 -1.39620999e-01 + -2.65765643e+00 1.06091403e+00 + 2.45992927e+00 -5.88815836e-01 + -1.28440162e+00 -1.99377398e-01 + 6.11257504e-01 -3.73577401e-01 + -3.46606103e-01 6.06081290e-01 + 3.76687505e+00 -8.80181424e-01 + -1.03725103e+00 1.45177517e+00 + 2.76659936e+00 -1.09361320e+00 + -3.61311296e+00 9.75032455e-01 + 3.22878655e+00 -9.69497365e-01 + 1.43560379e+00 -5.52524585e-01 + 2.94042153e+00 -1.79747037e+00 + 1.30739580e+00 2.47989248e-01 + -4.05056982e-01 1.22831715e+00 + -2.25827421e+00 2.30604626e-01 + 3.69262926e-01 4.32714650e-02 + -5.52064063e-01 6.07806340e-01 + 7.03325987e+00 -2.17956730e+00 + -2.37823835e-01 -8.28068639e-01 + -4.84279888e-01 5.67765194e-01 + -3.15863410e+00 1.02241617e+00 + -3.39561593e+00 1.36876374e+00 + -2.78482934e+00 6.81641104e-01 + -4.37604334e+00 2.23826340e+00 + -2.54049692e+00 8.22676745e-01 + 3.73264822e+00 -9.93498732e-01 + -3.49536064e+00 1.84771519e+00 + 9.81801604e-01 -5.21278776e-01 + 1.52996831e+00 -1.27386206e+00 + -9.23490293e-01 5.29099482e-01 + -2.76999461e+00 9.24831872e-01 + -3.30029834e-01 -2.49645555e-01 + -1.71156166e+00 5.44940854e-01 + -2.37009487e+00 5.83826982e-01 + -3.03216865e+00 1.04922722e+00 + -2.19539936e+00 1.37558730e+00 + 1.15350207e+00 -6.15318535e-01 + 4.62011792e+00 -2.46714517e+00 + 1.52627952e-02 -1.00618283e-01 + -1.10399342e+00 4.87413533e-01 + 3.55448194e+00 -9.10394190e-01 + -5.21890321e+00 2.44710745e+00 + 1.54289749e+00 -6.54269311e-01 + 2.67935674e+00 -9.92758863e-01 + 1.05801310e+00 2.60054285e-02 + 1.52509097e+00 -4.08768600e-01 + 3.27576917e+00 -1.28769406e+00 + 1.71008412e-01 -2.68739994e-01 + -9.83351344e-04 7.02495897e-02 + -7.60795056e-03 1.61968285e-01 + -1.80620472e+00 4.24934471e-01 + 2.32023297e-02 -2.57284559e-01 + 3.98219478e-01 -4.65361935e-01 + 6.63476988e-01 -3.29823196e-02 + 4.00154707e+00 -1.01792211e+00 + -1.50286870e+00 9.46875359e-01 + -2.22717585e+00 7.50636195e-01 + -3.47381508e-01 -6.51596975e-01 + 2.08076453e+00 -8.22800165e-01 + 2.05099963e+00 -4.00868250e-01 + 3.52576988e-02 -2.54418565e-01 + 1.57342042e+00 -7.62166492e-02 + -1.47019722e+00 3.40861172e-01 + -1.21156090e+00 3.21891246e-01 + 3.79729047e+00 -1.54350764e+00 + 1.26459678e-02 6.99203693e-01 + 1.53974177e-01 4.68643204e-01 + -1.73923561e-01 -1.26229768e-01 + 4.54644993e+00 -2.13951783e+00 + 1.46022547e-01 -4.57084165e-01 + 6.50048037e+00 -2.78872609e+00 + -1.51934912e+00 1.03216768e+00 + -3.06483575e+00 1.81101446e+00 + -2.38212125e+00 9.19559042e-01 + -1.81319611e+00 8.10545112e-01 + 1.70951294e+00 -6.10712680e-01 + 1.67974156e+00 -1.51241453e+00 + -5.94795113e+00 2.56893813e+00 + 3.62633110e-01 -7.46965304e-01 + -2.44042594e+00 8.52761797e-01 + 3.32412550e+00 -1.28439899e+00 + 4.74860766e+00 -1.72821964e+00 + 1.29072541e+00 -8.24872902e-01 + -1.69450702e+00 4.09600876e-01 + 1.29705411e+00 1.22300809e-01 + -2.63597613e+00 8.55612913e-01 + 9.28467301e-01 -2.63550114e-02 + 2.44670264e+00 -4.10123002e-01 + 1.06408206e+00 -5.03361942e-01 + 5.12384049e-02 -1.27116595e-02 + -1.06731272e+00 -1.76205029e-01 + -9.45454582e-01 3.74404917e-01 + 2.54343689e+00 -7.13810545e-01 + -2.54460335e+00 1.31590265e+00 + 1.89864233e+00 -3.98436339e-01 + -1.93990133e+00 6.01474630e-01 + -1.35938824e+00 4.00751788e-01 + 2.38567018e+00 -6.13904880e-01 + 2.18748050e-01 2.62631712e-01 + -2.01388788e+00 1.41474031e+00 + 2.74014581e+00 -1.27448105e+00 + -2.13828583e+00 1.13616144e+00 + 5.98730932e+00 -2.53430080e+00 + -1.72872795e+00 1.53702057e+00 + -2.53263962e+00 1.27342410e+00 + 1.34326968e+00 -1.99395088e-01 + 3.83352666e-01 -1.25683065e-01 + -2.35630657e+00 5.54116983e-01 + -1.94900838e+00 5.76270178e-01 + -1.36699108e+00 -3.40904824e-01 + -2.34727346e+00 -1.93054940e-02 + -3.82779777e+00 1.83025664e+00 + -4.31602080e+00 9.21605705e-01 + 5.54098133e-01 2.33991419e-01 + -4.53591188e+00 1.99833353e+00 + -3.92715909e+00 1.83231482e+00 + 3.91344440e-01 -1.11355111e-01 + 3.48576363e+00 -1.41379449e+00 + -1.42858690e+00 3.84532286e-01 + 1.79519859e+00 -9.23486448e-01 + 8.49691242e-01 -1.76551084e-01 + 1.53618138e+00 8.23835015e-02 + 5.91476520e-02 3.88296940e-02 + 1.44837346e+00 -7.24097604e-01 + -6.79008418e-01 4.04078097e-01 + 2.87555510e+00 -9.51825076e-01 + -1.12379101e+00 2.93457714e-01 + 1.45263980e+00 -6.01960544e-01 + -2.55741621e-01 9.26233518e-01 + 3.54570714e+00 -1.41521877e+00 + -1.61542388e+00 6.57844512e-01 + -3.22844269e-01 3.02823546e-01 + 1.03523913e+00 -6.92730711e-01 + 1.11084909e+00 -3.50823642e-01 + 3.41268693e+00 -1.90865862e+00 + 7.67062858e-01 -9.48792160e-01 + -5.49798016e+00 1.71139960e+00 + 1.14865798e+00 -6.12669150e-01 + -2.18256680e+00 7.78634462e-01 + 4.78857389e+00 -2.55555085e+00 + -1.85555569e+00 8.04311615e-01 + -4.22278799e+00 2.01162524e+00 + -1.56556149e+00 1.54353907e+00 + -3.11527864e+00 1.65973526e+00 + 2.66342611e+00 -1.20449402e+00 + 1.57635314e+00 -1.48716308e-01 + -6.35606865e-01 2.59701180e-01 + 1.02431976e+00 -6.76929904e-01 + 1.12973772e+00 1.49473892e-02 + -9.12758116e-01 2.21533933e-01 + -2.98014470e+00 1.71651189e+00 + 2.74016965e+00 -9.47893923e-01 + -3.47830591e+00 1.34941430e+00 + 1.74757562e+00 -3.72503752e-01 + 5.55820383e-01 -6.47992466e-01 + -1.19871928e+00 9.82429151e-01 + -2.53040133e+00 2.10671307e+00 + -1.94085605e+00 1.38938137e+00 diff --git a/data/mllib/sample_fpgrowth.txt b/data/mllib/sample_fpgrowth.txt new file mode 100644 index 0000000000000..c451583e51317 --- /dev/null +++ b/data/mllib/sample_fpgrowth.txt @@ -0,0 +1,6 @@ +r z h k p +z y x w v u t s +s x o n r +x z y m t s q e +z +x z y r q t p diff --git a/data/mllib/sample_isotonic_regression_data.txt b/data/mllib/sample_isotonic_regression_data.txt new file mode 100644 index 0000000000000..d257b509d4d37 --- /dev/null +++ b/data/mllib/sample_isotonic_regression_data.txt @@ -0,0 +1,100 @@ +0.24579296,0.01 +0.28505864,0.02 +0.31208567,0.03 +0.35900051,0.04 +0.35747068,0.05 +0.16675166,0.06 +0.17491076,0.07 +0.04181540,0.08 +0.04793473,0.09 +0.03926568,0.10 +0.12952575,0.11 +0.00000000,0.12 +0.01376849,0.13 +0.13105558,0.14 +0.08873024,0.15 +0.12595614,0.16 +0.15247323,0.17 +0.25956145,0.18 +0.20040796,0.19 +0.19581846,0.20 +0.15757267,0.21 +0.13717491,0.22 +0.19020908,0.23 +0.19581846,0.24 +0.20091790,0.25 +0.16879143,0.26 +0.18510964,0.27 +0.20040796,0.28 +0.29576747,0.29 +0.43396226,0.30 +0.53391127,0.31 +0.52116267,0.32 +0.48546660,0.33 +0.49209587,0.34 +0.54156043,0.35 +0.59765426,0.36 +0.56144824,0.37 +0.58592555,0.38 +0.52983172,0.39 +0.50178480,0.40 +0.52626211,0.41 +0.58286588,0.42 +0.64660887,0.43 +0.68077511,0.44 +0.74298827,0.45 +0.64864865,0.46 +0.67261601,0.47 +0.65782764,0.48 +0.69811321,0.49 +0.63029067,0.50 +0.61601224,0.51 +0.63233044,0.52 +0.65323814,0.53 +0.65323814,0.54 +0.67363590,0.55 +0.67006629,0.56 +0.51555329,0.57 +0.50892402,0.58 +0.33299337,0.59 +0.36206017,0.60 +0.43090260,0.61 +0.45996940,0.62 +0.56348802,0.63 +0.54920959,0.64 +0.48393677,0.65 +0.48495665,0.66 +0.46965834,0.67 +0.45181030,0.68 +0.45843957,0.69 +0.47118817,0.70 +0.51555329,0.71 +0.58031617,0.72 +0.55481897,0.73 +0.56297807,0.74 +0.56603774,0.75 +0.57929628,0.76 +0.64762876,0.77 +0.66241713,0.78 +0.69301377,0.79 +0.65119837,0.80 +0.68332483,0.81 +0.66598674,0.82 +0.73890872,0.83 +0.73992861,0.84 +0.84242733,0.85 +0.91330954,0.86 +0.88016318,0.87 +0.90719021,0.88 +0.93115757,0.89 +0.93115757,0.90 +0.91942886,0.91 +0.92911780,0.92 +0.95665477,0.93 +0.95002550,0.94 +0.96940337,0.95 +1.00000000,0.96 +0.89801122,0.97 +0.90311066,0.98 +0.90362060,0.99 +0.83477817,1.0 \ No newline at end of file diff --git a/data/mllib/sample_lda_data.txt b/data/mllib/sample_lda_data.txt new file mode 100644 index 0000000000000..2e76702ca9d67 --- /dev/null +++ b/data/mllib/sample_lda_data.txt @@ -0,0 +1,12 @@ +1 2 6 0 2 3 1 1 0 0 3 +1 3 0 1 3 0 0 2 0 0 1 +1 4 1 0 0 4 9 0 1 2 0 +2 1 0 3 0 0 5 0 2 3 9 +3 1 1 9 3 0 2 0 0 1 3 +4 2 0 3 4 5 1 1 1 4 0 +2 1 0 3 0 0 5 0 2 2 9 +1 1 1 9 2 1 2 0 0 1 3 +4 4 0 3 4 2 1 3 0 0 0 +2 8 2 0 3 0 2 0 2 7 2 +1 1 1 9 0 2 2 0 0 3 3 +4 1 0 0 4 5 1 3 0 1 0 diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh index 7473c20d28e09..15e0c73b4295e 100755 --- a/dev/change-version-to-2.10.sh +++ b/dev/change-version-to-2.10.sh @@ -16,5 +16,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +# Note that this will not necessarily work as intended with non-GNU sed (e.g. OS X) + find . -name 'pom.xml' | grep -v target \ - | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.11|\1_2.10|g' {} + | xargs -I {} sed -i -e 's/\(artifactId.*\)_2.11/\1_2.10/g' {} + +# Also update in parent POM +sed -i -e '0,/2.112.10 in parent POM +sed -i -e '0,/2.102.11 "$JAR_DL" && mv "$JAR_DL" "$JAR" + curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" elif [ $(command -v wget) ]; then wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" else diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 607ce1c803507..6f87fcd6d4eb4 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -34,6 +34,9 @@ ASF_PASSWORD=${ASF_PASSWORD:-XXX} GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} GIT_BRANCH=${GIT_BRANCH:-branch-1.0} RELEASE_VERSION=${RELEASE_VERSION:-1.2.0} +# Allows publishing under a different version identifier than +# was present in the actual release sources (e.g. rc-X) +PUBLISH_VERSION=${PUBLISH_VERSION:-$RELEASE_VERSION} NEXT_VERSION=${NEXT_VERSION:-1.2.1} RC_NAME=${RC_NAME:-rc2} @@ -97,30 +100,35 @@ if [[ ! "$@" =~ --skip-publish ]]; then pushd spark git checkout --force $GIT_TAG + # Substitute in case published version is different than released + old="^\( \{2,4\}\)${RELEASE_VERSION}<\/version>$" + new="\1${PUBLISH_VERSION}<\/version>" + find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ + -e "s/${old}/${new}/" {} + # Using Nexus API documented here: # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API echo "Creating Nexus staging repository" - repo_request="Apache Spark $GIT_TAG" + repo_request="Apache Spark $GIT_TAG (published as $PUBLISH_VERSION)" out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ -H "Content-Type:application/xml" -v \ $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") echo "Created Nexus staging repository: $staged_repo_id" - rm -rf $SPARK_REPO - - mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ + build/mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.11.sh - mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ + build/mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dscala-2.11 -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.10.sh + rm -rf $SPARK_REPO pushd $SPARK_REPO # Remove any extra files generated during install @@ -197,6 +205,12 @@ if [[ ! "$@" =~ --skip-package ]]; then ./dev/change-version-to-2.11.sh fi + # Create new Zinc instances for each binary release to avoid interference + # that causes OOM's and random compiler crashes. + zinc_port=${zinc_port:-3030} + zinc_port=$[$zinc_port + 1] + export ZINC_PORT=$zinc_port + ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee ../binary-release-$NAME.log cd .. cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . @@ -237,7 +251,7 @@ if [[ ! "$@" =~ --skip-package ]]; then sbt/sbt clean cd docs # Compile docs with Java 7 to use nicer format - JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build + JAVA_HOME="$JAVA_7_HOME" PRODUCTION=1 RELEASE_VERSION="$RELEASE_VERSION" jekyll build echo "Copying release documentation" rc_docs_folder=${rc_folder}-docs ssh $ASF_USERNAME@people.apache.org \ diff --git a/dev/create-release/generate-changelist.py b/dev/create-release/generate-changelist.py index 2e1a35a629342..1abf7ff3a3a35 100755 --- a/dev/create-release/generate-changelist.py +++ b/dev/create-release/generate-changelist.py @@ -31,8 +31,8 @@ import traceback SPARK_HOME = os.environ["SPARK_HOME"] -NEW_RELEASE_VERSION = "1.0.0" -PREV_RELEASE_GIT_TAG = "v0.9.1" +NEW_RELEASE_VERSION = "1.3.1" +PREV_RELEASE_GIT_TAG = "v1.3.0" CHANGELIST = "CHANGES.txt" OLD_CHANGELIST = "%s.old" % (CHANGELIST) diff --git a/dev/run-tests b/dev/run-tests index 2257a566bb1bb..483958757a2dd 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -36,7 +36,7 @@ function handle_error () { } -# Build against the right verison of Hadoop. +# Build against the right version of Hadoop. { if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then @@ -77,7 +77,7 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" fi } -# Only run Hive tests if there are sql changes. +# Only run Hive tests if there are SQL changes. # Partial solution for SPARK-1455. if [ -n "$AMPLAB_JENKINS" ]; then git fetch origin master:master @@ -183,7 +183,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string # will be interpreted as a single test, which doesn't work. - SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test") + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test" "mllib/test") else SBT_MAVEN_TEST_ARGS=("test") fi diff --git a/docs/_config.yml b/docs/_config.yml index e2db274e1f619..1fd60473073d4 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -10,11 +10,12 @@ kramdown: include: - _static + - _modules # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.3.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.3.0 +SPARK_VERSION: 1.3.1 +SPARK_VERSION_SHORT: 1.3.1 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.21.0 diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 8841f7675d35e..2e88b3093652d 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -7,7 +7,9 @@ {{ page.title }} - Spark {{site.SPARK_VERSION_SHORT}} Documentation - + {% if page.description %} + + {% endif %} {% if page.redirect %} @@ -69,7 +71,7 @@
  • Spark Programming Guide
  • Spark Streaming
  • -
  • Spark SQL
  • +
  • DataFrames and SQL
  • MLlib (Machine Learning)
  • GraphX (Graph Processing)
  • Bagel (Pregel on Spark)
  • diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index 7e55131754a3f..c2fe6b0e286ce 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -1,6 +1,7 @@ --- layout: global -title: Bagel Programming Guide +displayTitle: Bagel Programming Guide +title: Bagel --- **Bagel will soon be superseded by [GraphX](graphx-programming-guide.html); we recommend that new users try GraphX instead.** diff --git a/docs/building-spark.md b/docs/building-spark.md index db69905813817..d12ee8b181c9e 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -9,6 +9,10 @@ redirect_from: "building-with-maven.html" Building Spark using Maven requires Maven 3.0.4 or newer and Java 6+. +**Note:** Building Spark with Java 7 or later can create JAR files that may not be +readable with early versions of Java 6, due to the large number of files in the JAR +archive. Build with Java 6 if this is an issue for your deployment. + # Building with `build/mvn` Spark now comes packaged with a self-contained Maven installation to ease building and deployment of Spark from source located under the `build/` directory. This script will automatically download and setup all necessary build requirements ([Maven](https://maven.apache.org/), [Scala](http://www.scala-lang.org/), and [Zinc](https://github.com/typesafehub/zinc)) locally within the `build/` directory itself. It honors any `mvn` binary if present already, however, will pull down its own copy of Scala and Zinc regardless to ensure proper version requirements are met. `build/mvn` execution acts as a pass through to the `mvn` call allowing easy transition from previous build methods. As an example, one can build a version of Spark as follows: @@ -111,9 +115,9 @@ To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` prop dev/change-version-to-2.11.sh mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package -Scala 2.11 support in Spark is experimental and does not support a few features. -Specifically, Spark's external Kafka library and JDBC component are not yet -supported in Scala 2.11 builds. +Scala 2.11 support in Spark does not support a few features due to dependencies +which are themselves not Scala 2.11 ready. Specifically, Spark's external +Kafka library and JDBC component are not yet supported in Scala 2.11 builds. # Spark Tests in Maven @@ -161,6 +165,8 @@ For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troub # Building Spark Debian Packages +_NOTE: Debian packaging is deprecated and is scheduled to be removed in Spark 1.4._ + The Maven build includes support for building a Debian package containing the assembly 'fat-jar', PySpark, and the necessary scripts and configuration files. This can be created by specifying the following: mvn -Pdeb -DskipTests clean package diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 6a75d5c457f02..7079de546e2f5 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -33,7 +33,11 @@ There are several useful things to note about this architecture: 2. Spark is agnostic to the underlying cluster manager. As long as it can acquire executor processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). -3. Because the driver schedules tasks on the cluster, it should be run close to the worker +3. The driver program must listen for and accept incoming connections from its executors throughout + its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + section](configuration.html#networking)). As such, the driver program must be network + addressable from the worker nodes. +4. Because the driver schedules tasks on the cluster, it should be run close to the worker nodes, preferably on the same local area network. If you'd like to send requests to the cluster remotely, it's better to open an RPC to the driver and have it submit operations from nearby than to run a driver far away from the worker nodes. diff --git a/docs/configuration.md b/docs/configuration.md index 62d3fca937b2c..9f97dd2299128 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark Configuration +displayTitle: Spark Configuration +title: Configuration --- * This will become a table of contents (this text will be scraped). {:toc} @@ -69,7 +70,9 @@ each line consists of a key and a value separated by whitespace. For example: Any values specified as flags or in the properties file will be passed on to the application and merged with those specified through SparkConf. Properties set directly on the SparkConf take highest precedence, then flags passed to `spark-submit` or `spark-shell`, then options -in the `spark-defaults.conf` file. +in the `spark-defaults.conf` file. A few configuration keys have been renamed since earlier +versions of Spark; in such cases, the older key names are still accepted, but take lower +precedence than any instance of the newer key. ## Viewing Spark Properties @@ -93,14 +96,6 @@ of the most common options to set are: The name of your application. This will appear in the UI and in log data.
    - - - - - @@ -108,23 +103,6 @@ of the most common options to set are: Number of cores to use for the driver process, only in cluster mode. - - - - - - - - - - - - - + + - - + + - + @@ -190,6 +165,14 @@ of the most common options to set are: Logs the effective SparkConf as INFO when a SparkContext is started. + + + + +
    Executor IDTotal Tasks Failed Tasks Succeeded TasksInputOutputShuffle ReadShuffle WriteShuffle Spill (Memory)Shuffle Spill (Disk) + Input Size / Records + + Output Size / Records + + + Shuffle Read Size / Records + + + Shuffle Write Size / Records + Shuffle Spill (Memory)Shuffle Spill (Disk)
    {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} - {Utils.bytesToString(v.inputBytes)} - {Utils.bytesToString(v.outputBytes)} - {Utils.bytesToString(v.shuffleRead)} - {Utils.bytesToString(v.shuffleWrite)} - {Utils.bytesToString(v.memoryBytesSpilled)} - {Utils.bytesToString(v.diskBytesSpilled)} + {s"${Utils.bytesToString(v.inputBytes)} / ${v.inputRecords}"} + + {s"${Utils.bytesToString(v.outputBytes)} / ${v.outputRecords}"} + + {s"${Utils.bytesToString(v.shuffleRead)} / ${v.shuffleReadRecords}"} + + {s"${Utils.bytesToString(v.shuffleWrite)} / ${v.shuffleWriteRecords}"} + + {Utils.bytesToString(v.memoryBytesSpilled)} + + {Utils.bytesToString(v.diskBytesSpilled)} +
    {UIUtils.formatDuration(millis.toLong)} @@ -273,35 +290,86 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { getFormattedTimeQuantiles(schedulerDelays) def getFormattedSizeQuantiles(data: Seq[Double]) = - Distribution(data).get.getQuantiles().map(d => {Utils.bytesToString(d.toLong)}{Utils.bytesToString(d.toLong)}{s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"}InputInput Size / RecordsOutputOutput Size / RecordsShuffle Read Blocked Time + + Shuffle Read Blocked Time + + + + Shuffle Read Size / Records + + Shuffle Read (Remote) + + Shuffle Remote Reads + + Shuffle WriteShuffle Write Size / Records
    - {inputReadable} + {s"$inputReadable / $inputRecords"} - {outputReadable} + {s"$outputReadable / $outputRecords"} - {shuffleReadReadable} + {s"$shuffleReadReadable / $shuffleReadRecords"} + + {shuffleReadRemoteReadable} - {shuffleWriteReadable} + {s"$shuffleWriteReadable / $shuffleWriteRecords"} {errorSummary}{details}
    spark.master(none) - The cluster manager to connect to. See the list of - allowed master URL's. -
    spark.driver.cores 1
    spark.driver.memory512m - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 512m, 2g). -
    spark.executor.memory512m - Amount of memory to use per executor process, in the same format as JVM memory strings - (e.g. 512m, 2g). -
    spark.driver.maxResultSize 1g @@ -137,38 +115,35 @@ of the most common options to set are:
    spark.serializerorg.apache.spark.serializer.
    JavaSerializer
    spark.driver.memory512m - Class to use for serializing objects that will be sent over the network or need to be cached - in serialized form. The default of Java serialization works with any Serializable Java object - but is quite slow, so we recommend using - org.apache.spark.serializer.KryoSerializer and configuring Kryo serialization - when speed is necessary. Can be any subclass of - - org.apache.spark.Serializer. + Amount of memory to use for the driver process, i.e. where SparkContext is initialized. + (e.g. 512m, 2g). + +
    Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-memory command line option + or in your default properties file.
    spark.kryo.classesToRegister(none)spark.executor.memory512m - If you use Kryo serialization, give a comma-separated list of custom class names to register - with Kryo. - See the tuning guide for more details. + Amount of memory to use per executor process, in the same format as JVM memory strings + (e.g. 512m, 2g).
    spark.kryo.registratorspark.extraListeners (none) - If you use Kryo serialization, set this class to register your custom classes with Kryo. This - property is useful if you need to register your classes in a custom way, e.g. to specify a custom - field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be - set to a class that extends - - KryoRegistrator. - See the tuning guide for more details. + A comma-separated list of classes that implement SparkListener; when initializing + SparkContext, instances of these classes will be created and registered with Spark's listener + bus. If a class has a single-argument constructor that accepts a SparkConf, that constructor + will be called; otherwise, a zero-argument constructor will be called. If no valid constructor + can be found, the SparkContext creation will fail with an exception.
    spark.master(none) + The cluster manager to connect to. See the list of + allowed master URL's. +
    Apart from these, the following properties are also available, and may be useful in some situations: @@ -198,17 +181,27 @@ Apart from these, the following properties are also available, and may be useful - + - + @@ -216,54 +209,56 @@ Apart from these, the following properties are also available, and may be useful - - + + - + - + - - + + @@ -277,30 +272,40 @@ Apart from these, the following properties are also available, and may be useful - + + + + + + - + - - + + @@ -326,6 +331,15 @@ Apart from these, the following properties are also available, and may be useful automatically. + + + + + @@ -336,40 +350,38 @@ Apart from these, the following properties are also available, and may be useful from JVM to Python worker for every task. +
    Property NameDefaultMeaning
    spark.driver.extraJavaOptionsspark.driver.extraClassPath (none) - A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. + Extra classpath entries to append to the classpath of the driver. + +
    Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-class-path command line option or in + your default properties file.
    spark.driver.extraClassPathspark.driver.extraJavaOptions (none) - Extra classpath entries to append to the classpath of the driver. + A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. + +
    Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-java-options command line option or in + your default properties file.
    (none) Set a special library path to use when launching the driver JVM. + +
    Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-library-path command line option or in + your default properties file.
    spark.executor.extraJavaOptions(none)spark.driver.userClassPathFirstfalse - A string of extra JVM options to pass to executors. For instance, GC settings or other - logging. Note that it is illegal to set Spark properties or heap size settings with this - option. Spark properties should be set using a SparkConf object or the - spark-defaults.conf file used with the spark-submit script. Heap size settings can be set - with spark.executor.memory. + (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading + classes in the the driver. This feature can be used to mitigate conflicts between Spark's + dependencies and user dependencies. It is currently an experimental feature. + + This is used in cluster mode only.
    spark.executor.extraClassPath (none) - Extra classpath entries to append to the classpath of executors. This exists primarily - for backwards-compatibility with older versions of Spark. Users typically should not need - to set this option. + Extra classpath entries to append to the classpath of executors. This exists primarily for + backwards-compatibility with older versions of Spark. Users typically should not need to set + this option.
    spark.executor.extraLibraryPathspark.executor.extraJavaOptions (none) - Set a special library path to use when launching executor JVM's. + A string of extra JVM options to pass to executors. For instance, GC settings or other logging. + Note that it is illegal to set Spark properties or heap size settings with this option. Spark + properties should be set using a SparkConf object or the spark-defaults.conf file used with the + spark-submit script. Heap size settings can be set with spark.executor.memory.
    spark.executor.logs.rolling.strategyspark.executor.extraLibraryPath (none) - Set the strategy of rolling of executor logs. By default it is disabled. It can - be set to "time" (time-based rolling) or "size" (size-based rolling). For "time", - use spark.executor.logs.rolling.time.interval to set the rolling interval. - For "size", use spark.executor.logs.rolling.size.maxBytes to set - the maximum file size for rolling. + Set a special library path to use when launching executor JVM's.
    spark.executor.logs.rolling.time.intervaldailyspark.executor.logs.rolling.maxRetainedFiles(none) - Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or - any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles - for automatic cleaning of old logs. + Sets the number of latest rolling log files that are going to be retained by the system. + Older log files will be deleted. Disabled by default.
    spark.executor.logs.rolling.maxRetainedFilesspark.executor.logs.rolling.strategy (none) - Sets the number of latest rolling log files that are going to be retained by the system. - Older log files will be deleted. Disabled by default. + Set the strategy of rolling of executor logs. By default it is disabled. It can + be set to "time" (time-based rolling) or "size" (size-based rolling). For "time", + use spark.executor.logs.rolling.time.interval to set the rolling interval. + For "size", use spark.executor.logs.rolling.size.maxBytes to set + the maximum file size for rolling. +
    spark.executor.logs.rolling.time.intervaldaily + Set the time interval by which the executor logs will be rolled over. + Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or + any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles + for automatic cleaning of old logs.
    spark.files.userClassPathFirstspark.executor.userClassPathFirst false - (Experimental) Whether to give user-added jars precedence over Spark's own jars when - loading classes in Executors. This feature can be used to mitigate conflicts between - Spark's dependencies and user dependencies. It is currently an experimental feature. - (Currently, this setting does not work for YARN, see SPARK-2996 for more details). + (Experimental) Same functionality as spark.driver.userClassPathFirst, but + applied to executor instances.
    spark.python.worker.memory512mspark.executorEnv.[EnvironmentVariableName](none) - Amount of memory to use per python worker process during aggregation, in the same - format as JVM memory strings (e.g. 512m, 2g). If the memory - used during aggregation goes above this amount, it will spill the data into disks. + Add the environment variable specified by EnvironmentVariableName to the Executor + process. The user can specify multiple of these to set multiple environment variables.
    spark.python.worker.memory512m + Amount of memory to use per python worker process during aggregation, in the same + format as JVM memory strings (e.g. 512m, 2g). If the memory + used during aggregation goes above this amount, it will spill the data into disks. +
    spark.python.worker.reuse true
    + +#### Shuffle Behavior + + - - + + - - + + - - + + -
    Property NameDefaultMeaning
    spark.executorEnv.[EnvironmentVariableName](none)spark.reducer.maxMbInFlight48 - Add the environment variable specified by EnvironmentVariableName to the Executor - process. The user can specify multiple of these to set multiple environment variables. + Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since + each output requires us to create a buffer to receive it, this represents a fixed memory + overhead per reduce task, so keep it small unless you have a large amount of memory.
    spark.mesos.executor.homedriver side SPARK_HOMEspark.shuffle.blockTransferServicenetty - Set the directory in which Spark is installed on the executors in Mesos. By default, the - executors will simply use the driver's Spark home directory, which may not be visible to - them. Note that this is only relevant if a Spark binary package is not specified through - spark.executor.uri. + Implementation to use for transferring shuffle and cached blocks between executors. There + are two implementations available: netty and nio. Netty-based + block transfer is intended to be simpler but equally efficient and is the default option + starting in 1.2.
    spark.mesos.executor.memoryOverheadexecutor memory * 0.07, with minimum of 384spark.shuffle.compresstrue - This value is an additive for spark.executor.memory, specified in MiB, - which is used to calculate the total Mesos task memory. A value of 384 - implies a 384MiB overhead. Additionally, there is a hard-coded 7% minimum - overhead. The final overhead will be the larger of either - `spark.mesos.executor.memoryOverhead` or 7% of `spark.executor.memory`. + Whether to compress map output files. Generally a good idea. Compression will use + spark.io.compression.codec.
    - -#### Shuffle Behavior - - @@ -381,55 +393,46 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + - - + + - + - - - - - - - + + @@ -441,6 +444,17 @@ Apart from these, the following properties are also available, and may be useful the default option starting in 1.2. + + + + + @@ -450,13 +464,19 @@ Apart from these, the following properties are also available, and may be useful - - + + + + + + +
    Property NameDefaultMeaning
    spark.shuffle.consolidateFiles false
    spark.shuffle.spilltruespark.shuffle.file.buffer.kb32 - If set to "true", limits the amount of memory used during reduces by spilling data out to disk. - This spilling threshold is specified by spark.shuffle.memoryFraction. + Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers + reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
    spark.shuffle.spill.compresstruespark.shuffle.io.maxRetries3 - Whether to compress data spilled during shuffles. Compression will use - spark.io.compression.codec. + (Netty only) Fetches that fail due to IO-related exceptions are automatically retried if this is + set to a non-zero value. This retry logic helps stabilize large shuffles in the face of long GC + pauses or transient network connectivity issues.
    spark.shuffle.memoryFraction0.2spark.shuffle.io.numConnectionsPerPeer1 - Fraction of Java heap to use for aggregation and cogroups during shuffles, if - spark.shuffle.spill is true. At any given time, the collective size of - all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will - begin to spill to disk. If spills are often, consider increasing this value at the expense of - spark.storage.memoryFraction. + (Netty only) Connections between hosts are reused in order to reduce connection buildup for + large clusters. For clusters with many hard disks and few hosts, this may result in insufficient + concurrency to saturate all disks, and so users may consider increasing this value.
    spark.shuffle.compressspark.shuffle.io.preferDirectBufs true - Whether to compress map output files. Generally a good idea. Compression will use - spark.io.compression.codec. -
    spark.shuffle.file.buffer.kb32 - Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers - reduce the number of disk seeks and system calls made in creating intermediate shuffle files. + (Netty only) Off-heap buffers are used to reduce garbage collection during shuffle and cache + block transfer. For environments where off-heap memory is tightly limited, users may wish to + turn this off to force all allocations from Netty to be on-heap.
    spark.reducer.maxMbInFlight48spark.shuffle.io.retryWait5 - Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since - each output requires us to create a buffer to receive it, this represents a fixed memory - overhead per reduce task, so keep it small unless you have a large amount of memory. + (Netty only) Seconds to wait between retries of fetches. The maximum delay caused by retrying + is simply maxRetries * retryWait, by default 15 seconds.
    spark.shuffle.memoryFraction0.2 + Fraction of Java heap to use for aggregation and cogroups during shuffles, if + spark.shuffle.spill is true. At any given time, the collective size of + all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will + begin to spill to disk. If spills are often, consider increasing this value at the expense of + spark.storage.memoryFraction. +
    spark.shuffle.sort.bypassMergeThreshold 200
    spark.shuffle.blockTransferServicenettyspark.shuffle.spilltrue - Implementation to use for transferring shuffle and cached blocks between executors. There - are two implementations available: netty and nio. Netty-based - block transfer is intended to be simpler but equally efficient and is the default option - starting in 1.2. + If set to "true", limits the amount of memory used during reduces by spilling data out to disk. + This spilling threshold is specified by spark.shuffle.memoryFraction. +
    spark.shuffle.spill.compresstrue + Whether to compress data spilled during shuffles. Compression will use + spark.io.compression.codec.
    @@ -465,26 +485,28 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + - - + + @@ -495,28 +517,26 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + - - + +
    Property NameDefaultMeaning
    spark.ui.port4040spark.eventLog.compressfalse - Port for your application's dashboard, which shows memory and workload data. + Whether to compress logged events, if spark.eventLog.enabled is true.
    spark.ui.retainedStages1000spark.eventLog.dirfile:///tmp/spark-events - How many stages the Spark UI and status APIs remember before garbage - collecting. + Base directory in which Spark events are logged, if spark.eventLog.enabled is true. + Within this base directory, Spark creates a sub-directory for each application, and logs the + events specific to the application in this directory. Users may want to set this to + a unified location like an HDFS directory so history files can be read by the history server.
    spark.ui.retainedJobs1000spark.eventLog.enabledfalse - How many jobs the Spark UI and status APIs remember before garbage - collecting. + Whether to log Spark events, useful for reconstructing the Web UI after the application has + finished.
    spark.eventLog.enabledfalsespark.ui.port4040 - Whether to log Spark events, useful for reconstructing the Web UI after the application has - finished. + Port for your application's dashboard, which shows memory and workload data.
    spark.eventLog.compressfalsespark.ui.retainedJobs1000 - Whether to compress logged events, if spark.eventLog.enabled is true. + How many jobs the Spark UI and status APIs remember before garbage + collecting.
    spark.eventLog.dirfile:///tmp/spark-eventsspark.ui.retainedStages1000 - Base directory in which Spark events are logged, if spark.eventLog.enabled is true. - Within this base directory, Spark creates a sub-directory for each application, and logs the - events specific to the application in this directory. Users may want to set this to - a unified location like an HDFS directory so history files can be read by the history server. + How many stages the Spark UI and status APIs remember before garbage + collecting.
    @@ -532,12 +552,10 @@ Apart from these, the following properties are also available, and may be useful - spark.rdd.compress - false + spark.closure.serializer + org.apache.spark.serializer.
    JavaSerializer - Whether to compress serialized RDD partitions (e.g. for - StorageLevel.MEMORY_ONLY_SER). Can save substantial space at the cost of some - extra CPU time. + Serializer class to use for closures. Currently only the Java serializer is supported. @@ -553,14 +571,6 @@ Apart from these, the following properties are also available, and may be useful and org.apache.spark.io.SnappyCompressionCodec. - - spark.io.compression.snappy.block.size - 32768 - - Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec - is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. - - spark.io.compression.lz4.block.size 32768 @@ -570,21 +580,20 @@ Apart from these, the following properties are also available, and may be useful - spark.closure.serializer - org.apache.spark.serializer.
    JavaSerializer + spark.io.compression.snappy.block.size + 32768 - Serializer class to use for closures. Currently only the Java serializer is supported. + Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec + is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. - spark.serializer.objectStreamReset - 100 + spark.kryo.classesToRegister + (none) - When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches - objects to prevent writing redundant data, however that stops garbage collection of those - objects. By calling 'reset' you flush that info from the serializer, and allow old - objects to be collected. To turn off this periodic reset set it to -1. - By default it will reset the serializer every 100 objects. + If you use Kryo serialization, give a comma-separated list of custom class names to register + with Kryo. + See the tuning guide for more details. @@ -609,12 +618,16 @@ Apart from these, the following properties are also available, and may be useful - spark.kryoserializer.buffer.mb - 0.064 + spark.kryo.registrator + (none) - Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer - per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max.mb if needed. + If you use Kryo serialization, set this class to register your custom classes with Kryo. This + property is useful if you need to register your classes in a custom way, e.g. to specify a custom + field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be + set to a class that extends + + KryoRegistrator. + See the tuning guide for more details. @@ -626,11 +639,80 @@ Apart from these, the following properties are also available, and may be useful inside Kryo. + + spark.kryoserializer.buffer.mb + 0.064 + + Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer + per core on each worker. This buffer will grow up to + spark.kryoserializer.buffer.max.mb if needed. + + + + spark.rdd.compress + false + + Whether to compress serialized RDD partitions (e.g. for + StorageLevel.MEMORY_ONLY_SER). Can save substantial space at the cost of some + extra CPU time. + + + + spark.serializer + org.apache.spark.serializer.
    JavaSerializer + + Class to use for serializing objects that will be sent over the network or need to be cached + in serialized form. The default of Java serialization works with any Serializable Java object + but is quite slow, so we recommend using + org.apache.spark.serializer.KryoSerializer and configuring Kryo serialization + when speed is necessary. Can be any subclass of + + org.apache.spark.Serializer. + + + + spark.serializer.objectStreamReset + 100 + + When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches + objects to prevent writing redundant data, however that stops garbage collection of those + objects. By calling 'reset' you flush that info from the serializer, and allow old + objects to be collected. To turn off this periodic reset set it to -1. + By default it will reset the serializer every 100 objects. + + #### Execution Behavior + + + + + + + + + + + + + + + - - + + + + + + + - - + + @@ -673,12 +766,23 @@ Apart from these, the following properties are also available, and may be useful - - - + + + + + + + + @@ -689,6 +793,15 @@ Apart from these, the following properties are also available, and may be useful increase it if you configure your own old generation size. + + + + + @@ -707,15 +820,6 @@ Apart from these, the following properties are also available, and may be useful directories on Tachyon file system. - - - - - @@ -723,106 +827,19 @@ Apart from these, the following properties are also available, and may be useful The URL of the underlying Tachyon file system in the TachyonStore. - - - - - - - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.broadcast.blockSize4096 + Size of each piece of a block in kilobytes for TorrentBroadcastFactory. + Too large a value decreases parallelism during broadcast (makes it slower); however, if it is + too small, BlockManager might take a performance hit. +
    spark.broadcast.factoryorg.apache.spark.broadcast.
    TorrentBroadcastFactory
    + Which broadcast implementation to use. +
    spark.cleaner.ttl(infinite) + Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks + generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be + forgotten. This is useful for running Spark for many hours / days (for example, running 24/7 in + case of Spark Streaming applications). Note that any RDD that persists in memory for more than + this duration will be cleared as well. +
    spark.default.parallelism @@ -649,19 +731,30 @@ Apart from these, the following properties are also available, and may be useful
    spark.broadcast.factoryorg.apache.spark.broadcast.
    TorrentBroadcastFactory
    spark.executor.heartbeatInterval10000Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let + the driver know that the executor is still alive and update it with metrics for in-progress + tasks.
    spark.files.fetchTimeout60 - Which broadcast implementation to use. + Communication timeout to use when fetching files added through SparkContext.addFile() from + the driver, in seconds.
    spark.broadcast.blockSize4096spark.files.useFetchCachetrue - Size of each piece of a block in kilobytes for TorrentBroadcastFactory. - Too large a value decreases parallelism during broadcast (makes it slower); however, if it is - too small, BlockManager might take a performance hit. + If set to true (default), file fetching will use a local cache that is shared by executors + that belong to the same application, which can improve task launching performance when + running many executors on the same host. If set to false, these caching optimizations will + be disabled and all executors will fetch their own copies of files. This optimization may be + disabled in order to use Spark local directories that reside on NFS filesystems (see + SPARK-6313 for more details).
    spark.files.fetchTimeout60 - Communication timeout to use when fetching files added through SparkContext.addFile() from - the driver, in seconds. - spark.hadoop.cloneConffalseIf set to true, clones a new Hadoop Configuration object for each task. This + option should be enabled to work around Configuration thread-safety issues (see + SPARK-2546 for more details). + This is disabled by default in order to avoid unexpected performance regressions for jobs that + are not affected by these issues.
    spark.hadoop.validateOutputSpecstrueIf set to true, validates the output specification (e.g. checking if the output directory already exists) + used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing + output directories. We recommend that users do not disable this except if trying to achieve compatibility with + previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + This setting is ignored for jobs generated through Spark Streaming's StreamingContext, since + data may need to be rewritten to pre-existing output directories during checkpoint recovery.
    spark.storage.memoryFraction
    spark.storage.memoryMapThreshold2097152 + Size of a block, in bytes, above which Spark memory maps when reading a block from disk. + This prevents Spark from memory mapping very small blocks. In general, memory + mapping has high overhead for blocks close to or below the page size of the operating system. +
    spark.storage.unrollFraction 0.2
    spark.storage.memoryMapThreshold2097152 - Size of a block, in bytes, above which Spark memory maps when reading a block from disk. - This prevents Spark from memory mapping very small blocks. In general, memory - mapping has high overhead for blocks close to or below the page size of the operating system. -
    spark.tachyonStore.url tachyon://localhost:19998
    spark.cleaner.ttl(infinite) - Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks - generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be - forgotten. This is useful for running Spark for many hours / days (for example, running 24/7 in - case of Spark Streaming applications). Note that any RDD that persists in memory for more than - this duration will be cleared as well. -
    spark.hadoop.validateOutputSpecstrueIf set to true, validates the output specification (e.g. checking if the output directory already exists) - used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing - output directories. We recommend that users do not disable this except if trying to achieve compatibility with - previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. - This setting is ignored for jobs generated through Spark Streaming's StreamingContext, since - data may need to be rewritten to pre-existing output directories during checkpoint recovery.
    spark.hadoop.cloneConffalseIf set to true, clones a new Hadoop Configuration object for each task. This - option should be enabled to work around Configuration thread-safety issues (see - SPARK-2546 for more details). - This is disabled by default in order to avoid unexpected performance regressions for jobs that - are not affected by these issues.
    spark.executor.heartbeatInterval10000Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let - the driver know that the executor is still alive and update it with metrics for in-progress - tasks.
    #### Networking - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + @@ -835,181 +852,139 @@ Apart from these, the following properties are also available, and may be useful - - - - - - - - - - - - + + - - - - - - - - - - - - + + - - + + - - + + - - + + -
    Property NameDefaultMeaning
    spark.driver.host(local hostname) - Hostname or IP address for the driver to listen on. - This is used for communicating with the executors and the standalone Master. -
    spark.driver.port(random) - Port for the driver to listen on. - This is used for communicating with the executors and the standalone Master. -
    spark.fileserver.port(random) - Port for the driver's HTTP file server to listen on. -
    spark.broadcast.port(random) - Port for the driver's HTTP broadcast server to listen on. - This is not relevant for torrent broadcast. -
    spark.replClassServer.port(random) - Port for the driver's HTTP class server to listen on. - This is only relevant for the Spark shell. -
    spark.blockManager.port(random) - Port for all block managers to listen on. These exist on both the driver and the executors. -
    spark.executor.port(random) - Port for the executor to listen on. This is used for communicating with the driver. -
    spark.port.maxRetries16spark.akka.failure-detector.threshold300.0 - Default maximum number of retries when binding to a port before giving up. + This is set to a larger value to disable failure detector that comes inbuilt akka. It can be + enabled again, if you plan to use this feature (Not recommended). This maps to akka's + `akka.remote.transport-failure-detector.threshold`. Tune this in combination of + `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to.
    spark.akka.threads4 - Number of actor threads to use for communication. Can be useful to increase on large clusters - when the driver has a lot of CPU cores. -
    spark.akka.timeout100 - Communication timeout between Spark nodes, in seconds. -
    spark.network.timeout120spark.akka.heartbeat.interval1000 - Default timeout for all network interactions, in seconds. This config will be used in - place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, - spark.storage.blockManagerSlaveTimeoutMs or - spark.shuffle.io.connectionTimeout, if they are not configured. + This is set to a larger value to disable the transport failure detector that comes built in to + Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger + interval value in seconds reduces network overhead and a smaller value ( ~ 1 s) might be more + informative for Akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` + if you need to. A likely positive use case for using failure detector would be: a sensistive + failure detector can help evict rogue executors quickly. However this is usually not the case + as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling + this leads to a lot of exchanges of heart beats between nodes leading to flooding the network + with those.
    spark.akka.heartbeat.pauses 6000 - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be - enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause - in seconds for akka. This can be used to control sensitivity to gc pauses. Tune this in - combination of `spark.akka.heartbeat.interval` and `spark.akka.failure-detector.threshold` - if you need to. -
    spark.akka.failure-detector.threshold300.0 - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be - enabled again, if you plan to use this feature (Not recommended). This maps to akka's - `akka.remote.transport-failure-detector.threshold`. Tune this in combination of - `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to. -
    spark.akka.heartbeat.interval1000 - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be - enabled again, if you plan to use this feature (Not recommended). A larger interval value in - seconds reduces network overhead and a smaller value ( ~ 1 s) might be more informative for - akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` and - `spark.akka.failure-detector.threshold` if you need to. Only positive use case for using - failure detector can be, a sensistive failure detector can help evict rogue executors really - quick. However this is usually not the case as gc pauses and network lags are expected in a - real Spark cluster. Apart from that enabling this leads to a lot of exchanges of heart beats - between nodes leading to flooding the network with those. + This is set to a larger value to disable the transport failure detector that comes built in to Akka. + It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart + beat pause in seconds for Akka. This can be used to control sensitivity to GC pauses. Tune + this along with `spark.akka.heartbeat.interval` if you need to.
    spark.shuffle.io.preferDirectBufstruespark.akka.threads4 - (Netty only) Off-heap buffers are used to reduce garbage collection during shuffle and cache - block transfer. For environments where off-heap memory is tightly limited, users may wish to - turn this off to force all allocations from Netty to be on-heap. + Number of actor threads to use for communication. Can be useful to increase on large clusters + when the driver has a lot of CPU cores.
    spark.shuffle.io.numConnectionsPerPeer1spark.akka.timeout100 - (Netty only) Connections between hosts are reused in order to reduce connection buildup for - large clusters. For clusters with many hard disks and few hosts, this may result in insufficient - concurrency to saturate all disks, and so users may consider increasing this value. + Communication timeout between Spark nodes, in seconds.
    spark.shuffle.io.maxRetries3spark.blockManager.port(random) - (Netty only) Fetches that fail due to IO-related exceptions are automatically retried if this is - set to a non-zero value. This retry logic helps stabilize large shuffles in the face of long GC - pauses or transient network connectivity issues. + Port for all block managers to listen on. These exist on both the driver and the executors.
    spark.shuffle.io.retryWait5spark.broadcast.port(random) - (Netty only) Seconds to wait between retries of fetches. The maximum delay caused by retrying - is simply maxRetries * retryWait, by default 15 seconds. + Port for the driver's HTTP broadcast server to listen on. + This is not relevant for torrent broadcast.
    - -#### Scheduling - - - - + + - - + + - - + + - - + + - - + + - - + + - - + + +
    Property NameDefaultMeaning
    spark.task.cpus1spark.driver.host(local hostname) - Number of cores to allocate for each task. + Hostname or IP address for the driver to listen on. + This is used for communicating with the executors and the standalone Master.
    spark.task.maxFailures4spark.driver.port(random) - Number of individual task failures before giving up on the job. - Should be greater than or equal to 1. Number of allowed retries = this value - 1. + Port for the driver to listen on. + This is used for communicating with the executors and the standalone Master.
    spark.scheduler.modeFIFOspark.executor.port(random) - The scheduling mode between - jobs submitted to the same SparkContext. Can be set to FAIR - to use fair sharing instead of queueing jobs one after another. Useful for - multi-user services. + Port for the executor to listen on. This is used for communicating with the driver.
    spark.cores.max(not set)spark.fileserver.port(random) - When running on a standalone deploy cluster or a - Mesos cluster in "coarse-grained" - sharing mode, the maximum amount of CPU cores to request for the application from - across the cluster (not from each machine). If not set, the default will be - spark.deploy.defaultCores on Spark's standalone cluster manager, or - infinite (all available cores) on Mesos. + Port for the driver's HTTP file server to listen on.
    spark.mesos.coarsefalsespark.network.timeout120 - If set to "true", runs over Mesos clusters in - "coarse-grained" sharing mode, - where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per - Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use - for the whole duration of the Spark job. + Default timeout for all network interactions, in seconds. This config will be used in + place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, + spark.storage.blockManagerSlaveTimeoutMs or + spark.shuffle.io.connectionTimeout, if they are not configured.
    spark.speculationfalsespark.port.maxRetries16 - If set to "true", performs speculative execution of tasks. This means if one or more tasks are - running slowly in a stage, they will be re-launched. + Default maximum number of retries when binding to a port before giving up.
    spark.speculation.interval100spark.replClassServer.port(random) - How often Spark will check for tasks to speculate, in milliseconds. + Port for the driver's HTTP class server to listen on. + This is only relevant for the Spark shell.
    + +#### Scheduling + + - - + + - - + + @@ -1025,19 +1000,19 @@ Apart from these, the following properties are also available, and may be useful - + - + @@ -1048,16 +1023,16 @@ Apart from these, the following properties are also available, and may be useful - - + + - + - + - - + + + + + + + - + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.speculation.quantile0.75spark.cores.max(not set) - Percentage of tasks which must be complete before speculation is enabled for a particular stage. + When running on a standalone deploy cluster or a + Mesos cluster in "coarse-grained" + sharing mode, the maximum amount of CPU cores to request for the application from + across the cluster (not from each machine). If not set, the default will be + spark.deploy.defaultCores on Spark's standalone cluster manager, or + infinite (all available cores) on Mesos.
    spark.speculation.multiplier1.5spark.localExecution.enabledfalse - How many times slower a task is than the median to be considered for speculation. + Enables Spark to run certain jobs, such as first() or take() on the driver, without sending + tasks to the cluster. This can make certain jobs execute very quickly, but may require + shipping a whole partition of data to the driver.
    spark.locality.wait.processspark.locality.wait.node spark.locality.wait - Customize the locality wait for process locality. This affects tasks that attempt to access - cached data in a particular executor process. + Customize the locality wait for node locality. For example, you can set this to 0 to skip + node locality and search immediately for rack locality (if your cluster has rack information).
    spark.locality.wait.nodespark.locality.wait.process spark.locality.wait - Customize the locality wait for node locality. For example, you can set this to 0 to skip - node locality and search immediately for rack locality (if your cluster has rack information). + Customize the locality wait for process locality. This affects tasks that attempt to access + cached data in a particular executor process.
    spark.scheduler.revive.interval1000spark.scheduler.maxRegisteredResourcesWaitingTime30000 - The interval length for the scheduler to revive the worker resource offers to run tasks + Maximum amount of time to wait for resources to register before scheduling begins (in milliseconds).
    spark.scheduler.minRegisteredResourcesRatio0.0 for Mesos and Standalone mode, 0.8 for YARN0.8 for YARN mode; 0.0 otherwise The minimum ratio of registered resources (registered resources / total expected resources) (resources are executors in yarn mode, CPU cores in standalone mode) @@ -1068,25 +1043,70 @@ Apart from these, the following properties are also available, and may be useful
    spark.scheduler.maxRegisteredResourcesWaitingTime30000spark.scheduler.modeFIFO - Maximum amount of time to wait for resources to register before scheduling begins + The scheduling mode between + jobs submitted to the same SparkContext. Can be set to FAIR + to use fair sharing instead of queueing jobs one after another. Useful for + multi-user services. +
    spark.scheduler.revive.interval1000 + The interval length for the scheduler to revive the worker resource offers to run tasks (in milliseconds).
    spark.localExecution.enabledspark.speculation false - Enables Spark to run certain jobs, such as first() or take() on the driver, without sending - tasks to the cluster. This can make certain jobs execute very quickly, but may require - shipping a whole partition of data to the driver. + If set to "true", performs speculative execution of tasks. This means if one or more tasks are + running slowly in a stage, they will be re-launched. +
    spark.speculation.interval100 + How often Spark will check for tasks to speculate, in milliseconds. +
    spark.speculation.multiplier1.5 + How many times slower a task is than the median to be considered for speculation. +
    spark.speculation.quantile0.75 + Percentage of tasks which must be complete before speculation is enabled for a particular stage. +
    spark.task.cpus1 + Number of cores to allocate for each task. +
    spark.task.maxFailures4 + Number of individual task failures before giving up on the job. + Should be greater than or equal to 1. Number of allowed retries = this value - 1.
    -#### Dynamic allocation +#### Dynamic Allocation @@ -1106,10 +1126,19 @@ Apart from these, the following properties are also available, and may be useful + + + + + + - @@ -1120,15 +1149,15 @@ Apart from these, the following properties are also available, and may be useful - + - + - - - - -
    Property NameDefaultMeaning
    spark.dynamicAllocation.executorIdleTimeout600 + If dynamic allocation is enabled and an executor has been idle for more than this duration + (in seconds), the executor will be removed. For more detail, see this + description. +
    spark.dynamicAllocation.initialExecutors spark.dynamicAllocation.minExecutors0 - Lower bound for the number of executors if dynamic allocation is enabled. + Initial number of executors to run if dynamic allocation is enabled.
    spark.dynamicAllocation.maxExecutors spark.dynamicAllocation.minExecutors0 - Initial number of executors to run if dynamic allocation is enabled. + Lower bound for the number of executors if dynamic allocation is enabled.
    spark.dynamicAllocation.schedulerBacklogTimeout605 If dynamic allocation is enabled and there have been pending tasks backlogged for more than this duration (in seconds), new executors will be requested. For more detail, see this @@ -1144,20 +1173,30 @@ Apart from these, the following properties are also available, and may be useful description.
    spark.dynamicAllocation.executorIdleTimeout600 - If dynamic allocation is enabled and an executor has been idle for more than this duration - (in seconds), the executor will be removed. For more detail, see this - description. -
    #### Security + + + + + + + + + + @@ -1174,6 +1213,15 @@ Apart from these, the following properties are also available, and may be useful not running on YARN and authentication is enabled. + + + + + @@ -1183,12 +1231,11 @@ Apart from these, the following properties are also available, and may be useful - - + + @@ -1205,16 +1252,6 @@ Apart from these, the following properties are also available, and may be useful -Dspark.com.test.filter1.params='param1=foo,param2=testing' - - - - - @@ -1223,23 +1260,6 @@ Apart from these, the following properties are also available, and may be useful user that started the Spark job has view access. - - - - - - - - - -
    Property NameDefaultMeaning
    spark.acls.enablefalse + Whether Spark acls should are enabled. If enabled, this checks to see if the user has + access permissions to view or modify the job. Note this requires the user to be known, + so if the user comes across as null no checks are done. Filters can be used with the UI + to authenticate and set the user. +
    spark.admin.aclsEmpty + Comma separated list of users/administrators that have view and modify access to all Spark jobs. + This can be used if you run on a shared cluster and have a set of administrators or devs who + help debug when things work. +
    spark.authenticate false
    spark.core.connection.ack.wait.timeout60 + Number of seconds for the connection to wait for ack to occur before timing + out and giving up. To avoid unwilling timeout caused by long pause like GC, + you can set larger value. +
    spark.core.connection.auth.wait.timeout 30
    spark.core.connection.ack.wait.timeout60spark.modify.aclsEmpty - Number of seconds for the connection to wait for ack to occur before timing - out and giving up. To avoid unwilling timeout caused by long pause like GC, - you can set larger value. + Comma separated list of users that have modify access to the Spark job. By default only the + user that started the Spark job has access to modify it (kill it for example).
    spark.acls.enablefalse - Whether Spark acls should are enabled. If enabled, this checks to see if the user has - access permissions to view or modify the job. Note this requires the user to be known, - so if the user comes across as null no checks are done. Filters can be used with the UI - to authenticate and set the user. -
    spark.ui.view.acls Empty
    spark.modify.aclsEmpty - Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). -
    spark.admin.aclsEmpty - Comma separated list of users/administrators that have view and modify access to all Spark jobs. - This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things work. -
    #### Encryption @@ -1263,6 +1283,23 @@ Apart from these, the following properties are also available, and may be useful file server.

    + + spark.ssl.enabledAlgorithms + Empty + + A comma separated list of ciphers. The specified ciphers must be supported by JVM. + The reference list of protocols one can find on + this + page. + + + + spark.ssl.keyPassword + None + + A password to the private key in key-store. + + spark.ssl.keyStore None @@ -1279,10 +1316,12 @@ Apart from these, the following properties are also available, and may be useful - spark.ssl.keyPassword + spark.ssl.protocol None - A password to the private key in key-store. + A protocol name. The protocol must be supported by JVM. The reference list of protocols + one can find on this + page. @@ -1300,25 +1339,6 @@ Apart from these, the following properties are also available, and may be useful A password to the trust-store. - - spark.ssl.protocol - None - - A protocol name. The protocol must be supported by JVM. The reference list of protocols - one can find on this - page. - - - - spark.ssl.enabledAlgorithms - Empty - - A comma separated list of ciphers. The specified ciphers must be supported by JVM. - The reference list of protocols one can find on - this - page. - - @@ -1337,9 +1357,9 @@ Apart from these, the following properties are also available, and may be useful spark.streaming.receiver.maxRate - infinite + not set - Maximum number records per second at which each receiver will receive data. + Maximum rate (number of records per second) at which each receiver will receive data. Effectively, each stream will consume at most this number of records per second. Setting this configuration to 0 or a negative number will put no limit on the rate. See the deployment guide @@ -1367,6 +1387,16 @@ Apart from these, the following properties are also available, and may be useful higher memory usage in Spark. + + spark.streaming.kafka.maxRatePerPartition + not set + + Maximum rate (number of records per second) at which data will be read from each Kafka + partition when using the new Kafka direct stream API. See the + Kafka Integration guide + for more details. + + #### Cluster Managers diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index d50f445d7ecc7..8c9a1e1262d8f 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -52,7 +52,7 @@ identify machines belonging to each cluster in the Amazon EC2 Console. ```bash export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 -./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a --spark-version=1.1.0 launch my-spark-cluster +./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a launch my-spark-cluster ``` - After everything launches, check that the cluster scheduler is up and sees diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index e298c51f8a5b7..826f6d8f371c7 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: GraphX Programming Guide +displayTitle: GraphX Programming Guide +title: GraphX +description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). diff --git a/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png b/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png deleted file mode 100644 index ed9adad11d03a..0000000000000 Binary files a/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png and /dev/null differ diff --git a/docs/img/cluster-overview.png b/docs/img/cluster-overview.png index 368274068e754..317554c5f2a5b 100644 Binary files a/docs/img/cluster-overview.png and b/docs/img/cluster-overview.png differ diff --git a/docs/img/cluster-overview.pptx b/docs/img/cluster-overview.pptx index af3c462cd904d..1b90d7ec5a7ae 100644 Binary files a/docs/img/cluster-overview.pptx and b/docs/img/cluster-overview.pptx differ diff --git a/docs/index.md b/docs/index.md index 171d6ddad62f3..b5b016e34795e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,8 @@ --- layout: global -title: Spark Overview +displayTitle: Spark Overview +title: Overview +description: Apache Spark SPARK_VERSION_SHORT documentation homepage --- Apache Spark is a fast and general-purpose cluster computing system. @@ -72,7 +74,7 @@ options for deployment: in all supported languages (Scala, Java, Python) * Modules built on Spark: * [Spark Streaming](streaming-programming-guide.html): processing real-time data streams - * [Spark SQL](sql-programming-guide.html): support for structured data and relational queries + * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing * [Bagel (Pregel on Spark)](bagel-programming-guide.html): older, simple graph processing model @@ -113,6 +115,8 @@ options for deployment: * [Spark Homepage](http://spark.apache.org) * [Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK) +* [Spark Community](http://spark.apache.org/community.html) resources, including local meetups +* [StackOverflow tag `apache-spark`](http://stackoverflow.com/questions/tagged/apache-spark) * [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/), @@ -121,11 +125,3 @@ options for deployment: * [Code Examples](http://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)) - -# Community - -To get help using Spark or keep up with Spark development, sign up for the [user mailing list](http://spark.apache.org/mailing-lists.html). - -If you're in the San Francisco Bay Area, there's a regular [Spark meetup](http://www.meetup.com/spark-users/) every few weeks. Come by to meet the developers and other users. - -Finally, if you'd like to contribute code to Spark, read [how to contribute](contributing-to-spark.html). diff --git a/docs/ml-guide.md b/docs/ml-guide.md index be178d7689fdd..771a07183e26f 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -23,13 +23,13 @@ to `spark.ml`. Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Spark ML API. -* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`SchemaRDD`](api/scala/index.html#org.apache.spark.sql.SchemaRDD) from Spark SQL as a dataset which can hold a variety of data types. +* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL as a dataset which can hold a variety of data types. E.g., a dataset could have different columns storing text, feature vectors, true labels, and predictions. -* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `SchemaRDD` into another `SchemaRDD`. +* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. E.g., an ML model is a `Transformer` which transforms an RDD with features into an RDD with predictions. -* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `SchemaRDD` to produce a `Transformer`. +* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. E.g., a learning algorithm is an `Estimator` which trains on a dataset and produces a model. * **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. @@ -39,20 +39,20 @@ E.g., a learning algorithm is an `Estimator` which trains on a dataset and produ ## ML Dataset Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. -Spark ML adopts the [`SchemaRDD`](api/scala/index.html#org.apache.spark.sql.SchemaRDD) from Spark SQL in order to support a variety of data types under a unified Dataset concept. +Spark ML adopts the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL in order to support a variety of data types under a unified Dataset concept. -`SchemaRDD` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. -In addition to the types listed in the Spark SQL guide, `SchemaRDD` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types. +`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. +In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types. -A `SchemaRDD` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. +A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. -Columns in a `SchemaRDD` are named. The code examples below use names such as "text," "features," and "label." +Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." ## ML Algorithms ### Transformers -A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `SchemaRDD` into another, generally by appending one or more columns. +A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `DataFrame` into another, generally by appending one or more columns. For example: * A feature transformer might take a dataset, read a column (e.g., text), convert it into a new column (e.g., feature vectors), append the new column to the dataset, and output the updated dataset. @@ -60,7 +60,7 @@ For example: ### Estimators -An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an `Estimator` implements a method `fit()` which accepts a `SchemaRDD` and produces a `Transformer`. +An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an `Estimator` implements a method `fit()` which accepts a `DataFrame` and produces a `Transformer`. For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling `fit()` trains a `LogisticRegressionModel`, which is a `Transformer`. ### Properties of ML Algorithms @@ -101,7 +101,7 @@ We illustrate this for the simple text document workflow. The figure below is f Above, the top row represents a `Pipeline` with three stages. The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). -The bottom row represents data flowing through the pipeline, where cylinders indicate `SchemaRDD`s. +The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. The `Pipeline.fit()` method is called on the original dataset which has raw text documents and labels. The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words into the dataset. The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the dataset. @@ -130,7 +130,7 @@ Each stage's `transform()` method updates the dataset and passes it to the next *DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. -*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `SchemaRDD`. +*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `DataFrame`. ## Parameters @@ -171,12 +171,12 @@ import org.apache.spark.sql.{Row, SQLContext} val conf = new SparkConf().setAppName("SimpleParamsExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) -import sqlContext._ +import sqlContext.implicits._ // Prepare training data. // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes -// into SchemaRDDs, where it uses the case class metadata to infer the schema. -val training = sparkContext.parallelize(Seq( +// into DataFrames, where it uses the case class metadata to infer the schema. +val training = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), @@ -192,7 +192,7 @@ lr.setMaxIter(10) .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. -val model1 = lr.fit(training) +val model1 = lr.fit(training.toDF) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -203,33 +203,35 @@ println("Model 1 was fit using parameters: " + model1.fittingParamMap) // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. -paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.5) // Specify multiple Params. +paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. // One can also combine ParamMaps. -val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Changes output column name. +val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. -val model2 = lr.fit(training, paramMapCombined) +val model2 = lr.fit(training.toDF, paramMapCombined) println("Model 2 was fit using parameters: " + model2.fittingParamMap) -// Prepare test documents. -val test = sparkContext.parallelize(Seq( +// Prepare test data. +val test = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) -// Make predictions on test documents using the Transformer.transform() method. +// Make predictions on test data using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. -// Note that model2.transform() outputs a 'probability' column instead of the usual 'score' -// column since we renamed the lr.scoreCol parameter previously. -model2.transform(test) - .select('features, 'label, 'probability, 'prediction) +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +model2.transform(test.toDF) + .select("features", "label", "myProbability", "prediction") .collect() - .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => - println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println("($features, $label) -> prob=$prob, prediction=$prediction") } + +sc.stop() {% endhighlight %}
@@ -244,23 +246,23 @@ import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Row; SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample"); JavaSparkContext jsc = new JavaSparkContext(conf); -JavaSQLContext jsql = new JavaSQLContext(jsc); +SQLContext jsql = new SQLContext(jsc); // Prepare training data. -// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes -// into SchemaRDDs, where it uses the case class metadata to infer the schema. +// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans +// into DataFrames, where it uses the bean metadata to infer the schema. List localTraining = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); -JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); +DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -281,13 +283,13 @@ System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap // We may alternatively specify parameters using a ParamMap. ParamMap paramMap = new ParamMap(); -paramMap.put(lr.maxIter(), 20); // Specify 1 Param. +paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. -paramMap.put(lr.regParam(), 0.1); +paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. // One can also combine ParamMaps. ParamMap paramMap2 = new ParamMap(); -paramMap2.put(lr.scoreCol(), "probability"); // Changes output column name. +paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. @@ -300,19 +302,19 @@ List localTest = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); -JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); +DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. -// Note that model2.transform() outputs a 'probability' column instead of the usual 'score' -// column since we renamed the lr.scoreCol parameter previously. -model2.transform(test).registerAsTable("results"); -JavaSchemaRDD results = - jsql.sql("SELECT features, label, probability, prediction FROM results"); -for (Row r: results.collect()) { +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +DataFrame results = model2.transform(test); +for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + +jsc.stop(); {% endhighlight %} @@ -330,6 +332,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} // Labeled and unlabeled instance types. @@ -337,14 +340,14 @@ import org.apache.spark.sql.{Row, SQLContext} case class LabeledDocument(id: Long, text: String, label: Double) case class Document(id: Long, text: String) -// Set up contexts. Import implicit conversions to SchemaRDD from sqlContext. +// Set up contexts. Import implicit conversions to DataFrame from sqlContext. val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) -import sqlContext._ +import sqlContext.implicits._ // Prepare training documents, which are labeled. -val training = sparkContext.parallelize(Seq( +val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -365,30 +368,32 @@ val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. -val model = pipeline.fit(training) +val model = pipeline.fit(training.toDF) // Prepare test documents, which are unlabeled. -val test = sparkContext.parallelize(Seq( +val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. -model.transform(test) - .select('id, 'text, 'score, 'prediction) +model.transform(test.toDF) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println("($id, $text) --> prob=$prob, prediction=$prediction") } + +sc.stop() {% endhighlight %}
{% highlight java %} -import java.io.Serializable; import java.util.List; import com.google.common.collect.Lists; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; @@ -396,45 +401,44 @@ import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; -import org.apache.spark.SparkConf; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. public class Document implements Serializable { - private Long id; + private long id; private String text; - public Document(Long id, String text) { + public Document(long id, String text) { this.id = id; this.text = text; } - public Long getId() { return this.id; } - public void setId(Long id) { this.id = id; } + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } public String getText() { return this.text; } public void setText(String text) { this.text = text; } } public class LabeledDocument extends Document implements Serializable { - private Double label; + private double label; - public LabeledDocument(Long id, String text, Double label) { + public LabeledDocument(long id, String text, double label) { super(id, text); this.label = label; } - public Double getLabel() { return this.label; } - public void setLabel(Double label) { this.label = label; } + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } } // Set up contexts. SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); JavaSparkContext jsc = new JavaSparkContext(conf); -JavaSQLContext jsql = new JavaSQLContext(jsc); +SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. List localTraining = Lists.newArrayList( @@ -442,8 +446,7 @@ List localTraining = Lists.newArrayList( new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); -JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); +DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -468,16 +471,62 @@ List localTest = Lists.newArrayList( new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); -JavaSchemaRDD test = - jsql.applySchema(jsc.parallelize(localTest), Document.class); +DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. -model.transform(test).registerAsTable("prediction"); -JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); -for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) +DataFrame predictions = model.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + +jsc.stop(); +{% endhighlight %} +
+ +
+{% highlight python %} +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import Row, SQLContext + +sc = SparkContext(appName="SimpleTextClassificationPipeline") +sqlContext = SQLContext(sc) + +# Prepare training documents, which are labeled. +LabeledDocument = Row("id", "text", "label") +training = sc.parallelize([(0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + +# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. +tokenizer = Tokenizer(inputCol="text", outputCol="words") +hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") +lr = LogisticRegression(maxIter=10, regParam=0.01) +pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + +# Fit the pipeline to training documents. +model = pipeline.fit(training) + +# Prepare test documents, which are unlabeled. +Document = Row("id", "text") +test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + +# Make predictions on test documents and print columns of interest. +prediction = model.transform(test) +selected = prediction.select("id", "text", "prediction") +for row in selected.collect(): + print row + +sc.stop() {% endhighlight %}
@@ -508,21 +557,26 @@ However, it is also a well-established method for choosing parameters which is m
{% highlight scala %} import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from case classes. +case class LabeledDocument(id: Long, text: String, label: Double) +case class Document(id: Long, text: String) + val conf = new SparkConf().setAppName("CrossValidatorExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) -import sqlContext._ +import sqlContext.implicits._ // Prepare training documents, which are labeled. -val training = sparkContext.parallelize(Seq( +val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -565,24 +619,24 @@ crossval.setEstimatorParamMaps(paramGrid) crossval.setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. -val cvModel = crossval.fit(training) -// Get the best LogisticRegression model (with the best set of parameters from paramGrid). -val lrModel = cvModel.bestModel +val cvModel = crossval.fit(training.toDF) // Prepare test documents, which are unlabeled. -val test = sparkContext.parallelize(Seq( +val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test) - .select('id, 'text, 'score, 'prediction) +cvModel.transform(test.toDF) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } + +sc.stop() {% endhighlight %}
@@ -592,7 +646,6 @@ import java.util.List; import com.google.common.collect.Lists; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.Model; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.LogisticRegression; @@ -603,13 +656,43 @@ import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from Java Beans. +public class Document implements Serializable { + private long id; + private String text; + + public Document(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } + + public String getText() { return this.text; } + public void setText(String text) { this.text = text; } +} + +public class LabeledDocument extends Document implements Serializable { + private double label; + + public LabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } +} SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); JavaSparkContext jsc = new JavaSparkContext(conf); -JavaSQLContext jsql = new JavaSQLContext(jsc); +SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. List localTraining = Lists.newArrayList( @@ -625,8 +708,7 @@ List localTraining = Lists.newArrayList( new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); -JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); +DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -660,8 +742,6 @@ crossval.setNumFolds(2); // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. CrossValidatorModel cvModel = crossval.fit(training); -// Get the best LogisticRegression model (with the best set of parameters from paramGrid). -Model lrModel = cvModel.bestModel(); // Prepare test documents, which are unlabeled. List localTest = Lists.newArrayList( @@ -669,15 +749,16 @@ List localTest = Lists.newArrayList( new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); -JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); +DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test).registerAsTable("prediction"); -JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); -for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) +DataFrame predictions = cvModel.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + +jsc.stop(); {% endhighlight %} @@ -686,6 +767,21 @@ for (Row r: predictions.collect()) { # Dependencies Spark ML currently depends on MLlib and has the same dependencies. -Please see the [MLlib Dependencies guide](mllib-guide.html#Dependencies) for more info. +Please see the [MLlib Dependencies guide](mllib-guide.html#dependencies) for more info. Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies. + +# Migration Guide + +## From 1.2 to 1.3 + +The main API changes are from Spark SQL. We list the most important changes here: + +* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. +* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. +* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. + +Other changes were in `LogisticRegression`: + +* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). +* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 719cc95767b00..8e91d62f4a907 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -17,13 +17,13 @@ the supported algorithms for each type of problem. - Binary Classificationlinear SVMs, logistic regression, decision trees, naive Bayes + Binary Classificationlinear SVMs, logistic regression, decision trees, random forests, gradient-boosted trees, naive Bayes - Multiclass Classificationdecision trees, naive Bayes + Multiclass Classificationdecision trees, random forests, naive Bayes - Regressionlinear least squares, Lasso, ridge regression, decision trees + Regressionlinear least squares, Lasso, ridge regression, decision trees, random forests, gradient-boosted trees, isotonic regression @@ -34,4 +34,8 @@ More details for these methods can be found here: * [binary classification (SVMs, logistic regression)](mllib-linear-methods.html#binary-classification) * [linear regression (least squares, Lasso, ridge)](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) * [Decision trees](mllib-decision-tree.html) +* [Ensembles of decision trees](mllib-ensembles.html) + * [random forests](mllib-ensembles.html#random-forests) + * [gradient-boosted trees](mllib-ensembles.html#gradient-boosted-trees-gbts) * [Naive Bayes](mllib-naive-bayes.html) +* [Isotonic regression](mllib-isotonic-regression.html) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 413b824e369da..0b6db4fcb7b1f 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -4,25 +4,25 @@ title: Clustering - MLlib displayTitle: MLlib - Clustering --- -* Table of contents -{:toc} - - -## Clustering - Clustering is an unsupervised learning problem whereby we aim to group subsets of entities with one another based on some notion of similarity. Clustering is often used for exploratory analysis and/or as a component of a hierarchical supervised learning pipeline (in which distinct classifiers or regression -models are trained for each cluster). +models are trained for each cluster). + +MLlib supports the following models: -MLlib supports -[k-means](http://en.wikipedia.org/wiki/K-means_clustering) clustering, one of -the most commonly used clustering algorithms that clusters the data points into +* Table of contents +{:toc} + +## K-means + +[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +most commonly used clustering algorithms that clusters the data points into a predefined number of clusters. The MLlib implementation includes a parallelized variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). -The implementation in MLlib has the following parameters: +The implementation in MLlib has the following parameters: * *k* is the number of desired clusters. * *maxIterations* is the maximum number of iterations to run. @@ -32,29 +32,9 @@ initialization via k-means\|\|. guaranteed to find a globally optimal solution, and when run multiple times on a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. -* *epsilon* determines the distance threshold within which we consider k-means to have converged. - -### Power Iteration Clustering +* *epsilon* determines the distance threshold within which we consider k-means to have converged. -Power iteration clustering is a scalable and efficient algorithm for clustering points given pointwise mutual affinity values. Internally the algorithm: - -* accepts a [Graph](https://spark.apache.org/docs/0.9.2/api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points. -* calculates the principal eigenvalue and eigenvector -* Clusters each of the input points according to their principal eigenvector component value - -Details of this algorithm are found within [Power Iteration Clustering, Lin and Cohen]{www.icml2010.org/papers/387.pdf} - -Example outputs for a dataset inspired by the paper - but with five clusters instead of three- have he following output from our implementation: - -

- The Property Graph - -

- -### Examples +**Examples**
@@ -168,41 +148,370 @@ print("Within Set Sum of Squared Error = " + str(WSSSE))
-In order to run the above application, follow the instructions -provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) -section of the Spark -Quick Start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. +## Gaussian mixture + +A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) +represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions, +each with its own probability. The MLlib implementation uses the +[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) + algorithm to induce the maximum-likelihood model given a set of samples. The implementation +has the following parameters: + +* *k* is the number of desired clusters. +* *convergenceTol* is the maximum change in log-likelihood at which we consider convergence achieved. +* *maxIterations* is the maximum number of iterations to perform without reaching convergence. +* *initialModel* is an optional starting point from which to start the EM algorithm. If this parameter is omitted, a random starting point will be constructed from the data. + +**Examples** + +
+
+In the following example after loading and parsing data, we use a +[GaussianMixture](api/scala/index.html#org.apache.spark.mllib.clustering.GaussianMixture) +object to cluster the data into two clusters. The number of desired clusters is passed +to the algorithm. We then output the parameters of the mixture model. + +{% highlight scala %} +import org.apache.spark.mllib.clustering.GaussianMixture +import org.apache.spark.mllib.linalg.Vectors + +// Load and parse the data +val data = sc.textFile("data/mllib/gmm_data.txt") +val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))).cache() + +// Cluster the data into two classes using GaussianMixture +val gmm = new GaussianMixture().setK(2).run(parsedData) + +// output parameters of max-likelihood model +for (i <- 0 until gmm.k) { + println("weight=%f\nmu=%s\nsigma=\n%s\n" format + (gmm.weights(i), gmm.gaussians(i).mu, gmm.gaussians(i).sigma)) +} + +{% endhighlight %} +
+ +
+All of MLlib's methods use Java-friendly types, so you can import and call them there the same +way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the +Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by +calling `.rdd()` on your `JavaRDD` object. A self-contained application example +that is equivalent to the provided example in Scala is given below: + +{% highlight java %} +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.GaussianMixture; +import org.apache.spark.mllib.clustering.GaussianMixtureModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class GaussianMixtureExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("GaussianMixture Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse data + String path = "data/mllib/gmm_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + parsedData.cache(); + + // Cluster the data into two classes using GaussianMixture + GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); + + // Output the parameters of the mixture model + for(int j=0; j + +
+In the following example after loading and parsing data, we use a +[GaussianMixture](api/python/pyspark.mllib.html#pyspark.mllib.clustering.GaussianMixture) +object to cluster the data into two clusters. The number of desired clusters is passed +to the algorithm. We then output the parameters of the mixture model. -## Streaming clustering +{% highlight python %} +from pyspark.mllib.clustering import GaussianMixture +from numpy import array + +# Load and parse the data +data = sc.textFile("data/mllib/gmm_data.txt") +parsedData = data.map(lambda line: array([float(x) for x in line.strip().split(' ')])) + +# Build the model (cluster the data) +gmm = GaussianMixture.train(parsedData, 2) -When data arrive in a stream, we may want to estimate clusters dynamically, -updating them as new data arrive. MLlib provides support for streaming k-means clustering, -with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm -uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign +# output parameters of model +for i in range(2): + print ("weight = ", gmm.weights[i], "mu = ", gmm.gaussians[i].mu, + "sigma = ", gmm.gaussians[i].sigma.toArray()) + +{% endhighlight %} +
+ +
+ +## Power iteration clustering (PIC) + +Power iteration clustering (PIC) is a scalable and efficient algorithm for clustering vertices of a +graph given pairwise similarties as edge properties, +described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf). +It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via +[power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices. +MLlib includes an implementation of PIC using GraphX as its backend. +It takes an `RDD` of `(srcId, dstId, similarity)` tuples and outputs a model with the clustering assignments. +The similarities must be nonnegative. +PIC assumes that the similarity measure is symmetric. +A pair `(srcId, dstId)` regardless of the ordering should appear at most once in the input data. +If a pair is missing from input, their similarity is treated as zero. +MLlib's PIC implementation takes the following (hyper-)parameters: + +* `k`: number of clusters +* `maxIterations`: maximum number of power iterations +* `initializationMode`: initialization model. This can be either "random", which is the default, + to use a random vector as vertex properties, or "degree" to use normalized sum similarities. + +**Examples** + +In the following, we show code snippets to demonstrate how to use PIC in MLlib. + +
+
+ +[`PowerIterationClustering`](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClustering) +implements the PIC algorithm. +It takes an `RDD` of `(srcId: Long, dstId: Long, similarity: Double)` tuples representing the +affinity matrix. +Calling `PowerIterationClustering.run` returns a +[`PowerIterationClusteringModel`](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClusteringModel), +which contains the computed clustering assignments. + +{% highlight scala %} +import org.apache.spark.mllib.clustering.PowerIterationClustering +import org.apache.spark.mllib.linalg.Vectors + +val similarities: RDD[(Long, Long, Double)] = ... + +val pic = new PowerIteartionClustering() + .setK(3) + .setMaxIterations(20) +val model = pic.run(similarities) + +model.assignments.foreach { a => + println(s"${a.id} -> ${a.cluster}") +} +{% endhighlight %} + +A full example that produces the experiment described in the PIC paper can be found under +[`examples/`](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala). + +
+ +
+ +[`PowerIterationClustering`](api/java/org/apache/spark/mllib/clustering/PowerIterationClustering.html) +implements the PIC algorithm. +It takes an `JavaRDD` of `(srcId: Long, dstId: Long, similarity: Double)` tuples representing the +affinity matrix. +Calling `PowerIterationClustering.run` returns a +[`PowerIterationClusteringModel`](api/java/org/apache/spark/mllib/clustering/PowerIterationClusteringModel.html) +which contains the computed clustering assignments. + +{% highlight java %} +import scala.Tuple2; +import scala.Tuple3; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.clustering.PowerIterationClustering; +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; + +JavaRDD> similarities = ... + +PowerIterationClustering pic = new PowerIterationClustering() + .setK(2) + .setMaxIterations(10); +PowerIterationClusteringModel model = pic.run(similarities); + +for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { + System.out.println(a.id() + " -> " + a.cluster()); +} +{% endhighlight %} +
+ +
+ +## Latent Dirichlet allocation (LDA) + +[Latent Dirichlet allocation (LDA)](http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) +is a topic model which infers topics from a collection of text documents. +LDA can be thought of as a clustering algorithm as follows: + +* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. +* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts. +* Rather than estimating a clustering using a traditional distance, LDA uses a function based + on a statistical model of how text documents are generated. + +LDA takes in a collection of documents as vectors of word counts. +It learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) +on the likelihood function. After fitting on the documents, LDA provides: + +* Topics: Inferred topics, each of which is a probability distribution over terms (words). +* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics. + +LDA takes the following parameters: + +* `k`: Number of topics (i.e., cluster centers) +* `maxIterations`: Limit on the number of iterations of EM used for learning +* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions. +* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions. +* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. + +*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet +support prediction on new documents, and it does not have a Python API. These will be added in the future. + +**Examples** + +In the following example, we load word count vectors representing a corpus of documents. +We then use [LDA](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) +to infer three topics from the documents. The number of desired clusters is passed +to the algorithm. We then output the topics, represented as probability distributions over words. + +
+
+ +{% highlight scala %} +import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.linalg.Vectors + +// Load and parse the data +val data = sc.textFile("data/mllib/sample_lda_data.txt") +val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))) +// Index documents with unique IDs +val corpus = parsedData.zipWithIndex.map(_.swap).cache() + +// Cluster the documents into three topics using LDA +val ldaModel = new LDA().setK(3).run(corpus) + +// Output topics. Each is a distribution over words (matching word count vectors) +println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") +val topics = ldaModel.topicsMatrix +for (topic <- Range(0, 3)) { + print("Topic " + topic + ":") + for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + println() +} +{% endhighlight %} +
+ +
+{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.DistributedLDAModel; +import org.apache.spark.mllib.clustering.LDA; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class JavaLDAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("LDA Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_lda_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + // Index documents with unique IDs + JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 doc_id) { + return doc_id.swap(); + } + } + )); + corpus.cache(); + + // Cluster the documents into three topics using LDA + DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + + // Output topics. Each is a distribution over words (matching word count vectors) + System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() + + " words):"); + Matrix topics = ldaModel.topicsMatrix(); + for (int topic = 0; topic < 3; topic++) { + System.out.print("Topic " + topic + ":"); + for (int word = 0; word < ldaModel.vocabSize(); word++) { + System.out.print(" " + topics.apply(word, topic)); + } + System.out.println(); + } + } +} +{% endhighlight %} +
+ +
+ +## Streaming k-means + +When data arrive in a stream, we may want to estimate clusters dynamically, +updating them as new data arrive. MLlib provides support for streaming k-means clustering, +with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm +uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign all points to their nearest cluster, compute new cluster centers, then update each cluster using: `\begin{equation} c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} \end{equation}` `\begin{equation} - n_{t+1} = n_t + m_t + n_{t+1} = n_t + m_t \end{equation}` -Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned -to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` -is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` -can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; -with `$\alpha$=0` only the most recent data will be used. This is analogous to an -exponentially-weighted moving average. +Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned +to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` +is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` +can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; +with `$\alpha$=0` only the most recent data will be used. This is analogous to an +exponentially-weighted moving average. -The decay can be specified using a `halfLife` parameter, which determines the +The decay can be specified using a `halfLife` parameter, which determines the correct decay factor `a` such that, for data acquired at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5. The unit of time can be specified either as `batches` or `points` and the update rule will be adjusted accordingly. -### Examples +**Examples** This example shows how to estimate clusters on streaming data. @@ -220,9 +529,9 @@ import org.apache.spark.mllib.clustering.StreamingKMeans {% endhighlight %} -Then we make an input stream of vectors for training, as well as a stream of labeled data -points for testing. We assume a StreamingContext `ssc` has been created, see -[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. +Then we make an input stream of vectors for training, as well as a stream of labeled data +points for testing. We assume a StreamingContext `ssc` has been created, see +[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. {% highlight scala %} @@ -244,24 +553,24 @@ val model = new StreamingKMeans() {% endhighlight %} -Now register the streams for training and testing and start the job, printing +Now register the streams for training and testing and start the job, printing the predicted cluster assignments on new data points as they arrive. {% highlight scala %} model.trainOn(trainingData) -model.predictOnValues(testData).print() +model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - + {% endhighlight %} -As you add new text files with data the cluster centers will update. Each training +As you add new text files with data the cluster centers will update. Each training point should be formatted as `[x1, x2, x3]`, and each test data point -should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier -(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` +should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier +(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. With new data, the cluster centers will change!
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index ef18cec9371d6..76140282a2dd0 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -66,6 +66,7 @@ recommendation model by measuring the Mean Squared Error of rating prediction. {% highlight scala %} import org.apache.spark.mllib.recommendation.ALS +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel import org.apache.spark.mllib.recommendation.Rating // Load and parse the data @@ -95,6 +96,10 @@ val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => err * err }.mean() println("Mean Squared Error = " + MSE) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = MatrixFactorizationModel.load(sc, "myModelPath") {% endhighlight %} If the rating matrix is derived from another source of information (e.g., it is inferred from @@ -181,6 +186,10 @@ public class CollaborativeFiltering { } ).rdd()).mean(); System.out.println("Mean Squared Error = " + MSE); + + // Save and load model + model.save(sc.sc(), "myModelPath"); + MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -192,7 +201,7 @@ We use the default ALS.train() method which assumes ratings are explicit. We eva recommendation by measuring the Mean Squared Error of rating prediction. {% highlight python %} -from pyspark.mllib.recommendation import ALS, Rating +from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating # Load and parse the data data = sc.textFile("data/mllib/als/test.data") @@ -209,6 +218,10 @@ predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count() print("Mean Squared Error = " + str(MSE)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = MatrixFactorizationModel.load(sc, "myModelPath") {% endhighlight %} If the rating matrix is derived from other source of information (i.e., it is inferred from other diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 101dc2f8695f3..4f2a2f71048f7 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -78,13 +78,13 @@ MLlib recognizes the following types as dense vectors: and the following as sparse vectors: -* MLlib's [`SparseVector`](api/python/pyspark.mllib.linalg.SparseVector-class.html). +* MLlib's [`SparseVector`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.SparseVector). * SciPy's [`csc_matrix`](http://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html#scipy.sparse.csc_matrix) with a single column We recommend using NumPy arrays over lists for efficiency, and using the factory methods implemented -in [`Vectors`](api/python/pyspark.mllib.linalg.Vectors-class.html) to create sparse vectors. +in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vector) to create sparse vectors. {% highlight python %} import numpy as np @@ -151,7 +151,7 @@ LabeledPoint neg = new LabeledPoint(1.0, Vectors.sparse(3, new int[] {0, 2}, new
A labeled point is represented by -[`LabeledPoint`](api/python/pyspark.mllib.regression.LabeledPoint-class.html). +[`LabeledPoint`](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint). {% highlight python %} from pyspark.mllib.linalg import SparseVector @@ -211,7 +211,7 @@ JavaRDD examples =
-[`MLUtils.loadLibSVMFile`](api/python/pyspark.mllib.util.MLUtils-class.html) reads training +[`MLUtils.loadLibSVMFile`](api/python/pyspark.mllib.html#pyspark.mllib.util.MLUtils) reads training examples stored in LIBSVM format. {% highlight python %} @@ -296,6 +296,70 @@ backed by an RDD of its entries. The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. In general the use of non-deterministic RDDs can lead to errors. +### BlockMatrix + +A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is +a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is +the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`. +`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`. +`BlockMatrix` also has a helper function `validate` which can be used to check whether the +`BlockMatrix` is set up properly. + +
+
+ +A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} + +val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries +// Create a CoordinateMatrix from an RDD[MatrixEntry]. +val coordMat: CoordinateMatrix = new CoordinateMatrix(entries) +// Transform the CoordinateMatrix to a BlockMatrix +val matA: BlockMatrix = coordMat.toBlockMatrix().cache() + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate() + +// Calculate A^T A. +val ata = matA.transpose.multiply(matA) +{% endhighlight %} +
+ +
+ +A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.distributed.BlockMatrix; +import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; +import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix; + +JavaRDD entries = ... // a JavaRDD of (i, j, v) Matrix Entries +// Create a CoordinateMatrix from a JavaRDD. +CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd()); +// Transform the CoordinateMatrix to a BlockMatrix +BlockMatrix matA = coordMat.toBlockMatrix().cache(); + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate(); + +// Calculate A^T A. +BlockMatrix ata = matA.transpose().multiply(matA); +{% endhighlight %} +
+
+ ### RowMatrix A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, backed by an RDD diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index fc8e732251a30..c1d0f8a6b1cd8 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision Tree - MLlib -displayTitle: MLlib - Decision Tree +title: Decision Trees - MLlib +displayTitle: MLlib - Decision Trees --- * Table of contents @@ -54,8 +54,8 @@ impurity measure for regression (variance). Variance Regression - $\frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2$$y_i$ is label for an instance, - $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N x_i$. + $\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$$y_i$ is label for an instance, + $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N y_i$. @@ -194,6 +194,7 @@ maximum tree depth of 5. The test error is calculated to measure the algorithm a
{% highlight scala %} import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -221,6 +222,10 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification tree model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
@@ -279,13 +284,18 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification tree model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
+ {% highlight python %} from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -305,6 +315,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes print('Test Error = ' + str(testErr)) print('Learned classification tree model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
@@ -324,6 +338,7 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
{% highlight scala %} import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -350,6 +365,10 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression tree model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
@@ -414,13 +433,18 @@ Double testMSE = }) / data.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression tree model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
+ {% highlight python %} from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -440,6 +464,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression tree model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 23ede04b62d5b..84f3b3bf670f3 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -98,6 +98,7 @@ The test error is calculated to measure the algorithm accuracy.
{% highlight scala %} import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -127,6 +128,10 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification forest model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
@@ -188,12 +193,17 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification forest model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
+ {% highlight python %} -from pyspark.mllib.tree import RandomForest +from pyspark.mllib.tree import RandomForest, RandomForestModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -216,6 +226,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes print('Test Error = ' + str(testErr)) print('Learned classification forest model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
@@ -235,6 +249,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
{% highlight scala %} import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -264,6 +279,10 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression forest model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
@@ -328,12 +347,17 @@ Double testMSE = }) / testData.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression forest model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
+ {% highlight python %} -from pyspark.mllib.tree import RandomForest +from pyspark.mllib.tree import RandomForest, RandomForestModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -356,6 +380,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression forest model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
@@ -430,8 +458,6 @@ We omit some decision tree parameters since those are covered in the [decision t ### Examples -GBTs currently have APIs in Scala and Java. Examples in both languages are shown below. - #### Classification The example below demonstrates how to load a @@ -446,6 +472,7 @@ The test error is calculated to measure the algorithm accuracy. {% highlight scala %} import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -458,7 +485,7 @@ val (trainingData, testData) = (splits(0), splits(1)) // The defaultParams for Classification use LogLoss by default. val boostingStrategy = BoostingStrategy.defaultParams("Classification") boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. -boostingStrategy.treeStrategy.numClassesForClassification = 2 +boostingStrategy.treeStrategy.numClasses = 2 boostingStrategy.treeStrategy.maxDepth = 5 // Empty categoricalFeaturesInfo indicates all features are continuous. boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() @@ -473,6 +500,10 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification GBT model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %} @@ -534,6 +565,41 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification GBT model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); +{% endhighlight %} + + +
+ +{% highlight python %} +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils + +# Load and parse the data file. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GradientBoostedTrees model. +# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. +# (b) Use more iterations in practice. +model = GradientBoostedTrees.trainClassifier(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + +# Evaluate model on test instances and compute test error +predictions = model.predict(testData.map(lambda x: x.features)) +labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) +testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) +print('Test Error = ' + str(testErr)) +print('Learned classification GBT model:') +print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %}
@@ -554,6 +620,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate {% highlight scala %} import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -580,6 +647,10 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression GBT model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %} @@ -647,6 +718,41 @@ Double testMSE = }) / data.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression GBT model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); +{% endhighlight %} + + +
+ +{% highlight python %} +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils + +# Load and parse the data file. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GradientBoostedTrees model. +# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. +# (b) Use more iterations in practice. +model = GradientBoostedTrees.trainRegressor(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + +# Evaluate model on test instances and compute test error +predictions = model.predict(testData.map(lambda x: x.features)) +labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) +testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) +print('Test Mean Squared Error = ' + str(testMSE)) +print('Learned regression GBT model:') +print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %}
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index d4a61a7fbf3d7..80842b27effd8 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -375,3 +375,105 @@ data2 = labels.zip(normalizer2.transform(features)) {% endhighlight %} + +## Feature selection +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. + +### ChiSqSelector +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which are most closely related to the label. + +#### Model Fitting + +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the +following parameters in the constructor: + +* `numTopFeatures` number of top features that the selector will select (filter). + +We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in +`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then +return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. + +This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) +which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on +an `RDD[Vector]` to produce a reduced `RDD[Vector]`. + +Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order). + +#### Example + +The following example shows the basic use of ChiSqSelector. + +
+
+{% highlight scala %} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load some data in libsvm format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +// Discretize data in 16 equal bins since ChiSqSelector requires categorical features +val discretizedData = data.map { lp => + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => x / 16 } ) ) +} +// Create ChiSqSelector that will select 50 features +val selector = new ChiSqSelector(50) +// Create ChiSqSelector model (selecting features) +val transformer = selector.fit(discretizedData) +// Filter the top 50 features from each feature vector +val filteredData = discretizedData.map { lp => + LabeledPoint(lp.label, transformer.transform(lp.features)) +} +{% endhighlight %} +
+ +
+{% highlight java %} +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.feature.ChiSqSelector; +import org.apache.spark.mllib.feature.ChiSqSelectorModel; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; + +SparkConf sparkConf = new SparkConf().setAppName("JavaChiSqSelector"); +JavaSparkContext sc = new JavaSparkContext(sparkConf); +JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), + "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); + +// Discretize data in 16 equal bins since ChiSqSelector requires categorical features +JavaRDD discretizedData = points.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + final double[] discretizedFeatures = new double[lp.features().size()]; + for (int i = 0; i < lp.features().size(); ++i) { + discretizedFeatures[i] = lp.features().apply(i) / 16; + } + return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); + } + }); + +// Create ChiSqSelector that will select 50 features +ChiSqSelector selector = new ChiSqSelector(50); +// Create ChiSqSelector model (selecting features) +final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); +// Filter the top 50 features from each feature vector +JavaRDD filteredData = discretizedData.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + return new LabeledPoint(lp.label(), transformer.transform(lp.features())); + } + } +); + +sc.stop(); +{% endhighlight %} +
+
+ diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md new file mode 100644 index 0000000000000..9fd9be0dd01b1 --- /dev/null +++ b/docs/mllib-frequent-pattern-mining.md @@ -0,0 +1,98 @@ +--- +layout: global +title: Frequent Pattern Mining - MLlib +displayTitle: MLlib - Frequent Pattern Mining +--- + +Mining frequent items, itemsets, subsequences, or other substructures is usually among the +first steps to analyze a large-scale dataset, which has been an active research topic in +data mining for years. +We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) +for more information. +MLlib provides a parallel implementation of FP-growth, +a popular algorithm to mining frequent itemsets. + +## FP-growth + +The FP-growth algorithm is described in the paper +[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +where "FP" stands for frequent pattern. +Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. +Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, +the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets +explicitly, which are usually expensive to generate. +After the second step, the frequent itemsets can be extracted from the FP-tree. +In MLlib, we implemented a parallel version of FP-growth called PFP, +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +PFP distributes the work of growing FP-trees based on the suffices of transactions, +and hence more scalable than a single-machine implementation. +We refer users to the papers for more details. + +MLlib's FP-growth implementation takes the following (hyper-)parameters: + +* `minSupport`: the minimum support for an itemset to be identified as frequent. + For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. +* `numPartitions`: the number of partitions used to distribute the work. + +**Examples** + +
+
+ +[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the +FP-growth algorithm. +It take a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. +Calling `FPGrowth.run` with transactions returns an +[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +that stores the frequent itemsets with their frequencies. + +{% highlight scala %} +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} + +val transactions: RDD[Array[String]] = ... + +val fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10) +val model = fpg.run(transactions) + +model.freqItemsets.collect().foreach { itemset => + println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) +} +{% endhighlight %} + +
+ +
+ +[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the +FP-growth algorithm. +It take an `RDD` of transactions, where each transaction is an `Array` of items of a generic type. +Calling `FPGrowth.run` with transactions returns an +[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +that stores the frequent itemsets with their frequencies. + +{% highlight java %} +import java.util.List; + +import com.google.common.base.Joiner; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.fpm.FPGrowth; +import org.apache.spark.mllib.fpm.FPGrowthModel; + +JavaRDD> transactions = ... + +FPGrowth fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10); +FPGrowthModel model = fpg.run(transactions); + +for (FPGrowth.FreqItemset itemset: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); +} +{% endhighlight %} + +
+
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 39c64d06926bf..03b948c421f0b 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: Machine Learning Library (MLlib) Programming Guide +title: MLlib +displayTitle: Machine Learning Library (MLlib) Guide +description: MLlib machine learning library overview for Spark SPARK_VERSION_SHORT --- MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, @@ -19,14 +21,21 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [naive Bayes](mllib-naive-bayes.html) * [decision trees](mllib-decision-tree.html) * [ensembles of trees](mllib-ensembles.html) (Random Forests and Gradient-Boosted Trees) + * [isotonic regression](mllib-isotonic-regression.html) * [Collaborative filtering](mllib-collaborative-filtering.html) * alternating least squares (ALS) * [Clustering](mllib-clustering.html) - * k-means + * [k-means](mllib-clustering.html#k-means) + * [Gaussian mixture](mllib-clustering.html#gaussian-mixture) + * [power iteration clustering (PIC)](mllib-clustering.html#power-iteration-clustering-pic) + * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda) + * [streaming k-means](mllib-clustering.html#streaming-k-means) * [Dimensionality reduction](mllib-dimensionality-reduction.html) * singular value decomposition (SVD) * principal component analysis (PCA) * [Feature extraction and transformation](mllib-feature-extraction.html) +* [Frequent pattern mining](mllib-frequent-pattern-mining.html) + * FP-growth * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) @@ -37,7 +46,7 @@ and the migration guide below will explain all changes between releases. # spark.ml: high-level APIs for ML pipelines -Spark 1.2 includes a new package called `spark.ml`, which aims to provide a uniform set of +Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. It is currently an alpha component, and we would like to hear back from the community about how it fits real-world use cases and how it could be improved. @@ -52,149 +61,55 @@ See the **[spark.ml programming guide](ml-guide.html)** for more information on # Dependencies -MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), -which depends on [netlib-java](https://github.com/fommil/netlib-java), -and [jblas](https://github.com/mikiobraun/jblas). -`netlib-java` and `jblas` depend on native Fortran routines. -You need to install the +MLlib uses the linear algebra package +[Breeze](http://www.scalanlp.org/), which depends on +[netlib-java](https://github.com/fommil/netlib-java) for optimised +numerical processing. If natives are not available at runtime, you +will see a warning message and a pure JVM implementation will be used +instead. + +To learn more about the benefits and background of system optimised +natives, you may wish to watch Sam Halliday's ScalaX talk on +[High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/)). + +Due to licensing issues with runtime proprietary binaries, we do not +include `netlib-java`'s native proxies by default. To configure +`netlib-java` / Breeze to use system optimised binaries, include +`com.github.fommil.netlib:all:1.1.2` (or build Spark with +`-Pnetlib-lgpl`) as a dependency of your project and read the +[netlib-java](https://github.com/fommil/netlib-java) documentation for +your platform's additional installation instructions. + +MLlib also uses [jblas](https://github.com/mikiobraun/jblas) which +will require you to install the [gfortran runtime library](https://github.com/mikiobraun/jblas/wiki/Missing-Libraries) if it is not already present on your nodes. -MLlib will throw a linking error if it cannot detect these libraries automatically. -Due to license issues, we do not include `netlib-java`'s native libraries in MLlib's -dependency set under default settings. -If no native library is available at runtime, you will see a warning message. -To use native libraries from `netlib-java`, please build Spark with `-Pnetlib-lgpl` or -include `com.github.fommil.netlib:all:1.1.2` as a dependency of your project. -If you want to use optimized BLAS/LAPACK libraries such as -[OpenBLAS](http://www.openblas.net/), please link its shared libraries to -`/usr/lib/libblas.so.3` and `/usr/lib/liblapack.so.3`, respectively. -BLAS/LAPACK libraries on worker nodes should be built without multithreading. - -To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer. + +To use MLlib in Python, you will need [NumPy](http://www.numpy.org) +version 1.4 or newer. --- # Migration Guide -## From 1.1 to 1.2 - -The only API changes in MLlib v1.2 are in -[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -which continues to be an experimental API in MLlib 1.2: - -1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number -of classes. In MLlib v1.1, this argument was called `numClasses` in Python and -`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`. -This `numClasses` parameter is specified either via -[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) -or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) -static `trainClassifier` and `trainRegressor` methods. - -2. *(Breaking change)* The API for -[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed. -This should generally not affect user code, unless the user manually constructs decision trees -(instead of using the `trainClassifier` or `trainRegressor` methods). -The tree `Node` now includes more information, including the probability of the predicted label -(for classification). - -3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`. - -Examples in the Spark distribution and examples in the -[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly. - -## From 1.0 to 1.1 - -The only API changes in MLlib v1.1 are in -[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -which continues to be an experimental API in MLlib 1.1: - -1. *(Breaking change)* The meaning of tree depth has been changed by 1 in order to match -the implementations of trees in -[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree) -and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html). -In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes. -In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes. -This depth is specified by the `maxDepth` parameter in -[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) -or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) -static `trainClassifier` and `trainRegressor` methods. - -2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor` -methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -rather than using the old parameter class `Strategy`. These new training methods explicitly -separate classification and regression, and they replace specialized parameter types with -simple `String` types. - -Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the -[Decision Trees Guide](mllib-decision-tree.html#examples). - -## From 0.9 to 1.0 - -In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few -breaking changes. If your data is sparse, please store it in a sparse format instead of dense to -take advantage of sparsity in both storage and computation. Details are described below. - -
-
- -We used to represent a feature vector by `Array[Double]`, which is replaced by -[`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) in v1.0. Algorithms that used -to accept `RDD[Array[Double]]` now take -`RDD[Vector]`. [`LabeledPoint`](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) -is now a wrapper of `(Double, Vector)` instead of `(Double, Array[Double])`. Converting -`Array[Double]` to `Vector` is straightforward: - -{% highlight scala %} -import org.apache.spark.mllib.linalg.{Vector, Vectors} - -val array: Array[Double] = ... // a double array -val vector: Vector = Vectors.dense(array) // a dense vector -{% endhighlight %} - -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) provides factory methods to create sparse vectors. - -*Note*: Scala imports `scala.collection.immutable.Vector` by default, so you have to import `org.apache.spark.mllib.linalg.Vector` explicitly to use MLlib's `Vector`. - -
- -
- -We used to represent a feature vector by `double[]`, which is replaced by -[`Vector`](api/java/index.html?org/apache/spark/mllib/linalg/Vector.html) in v1.0. Algorithms that used -to accept `RDD` now take -`RDD`. [`LabeledPoint`](api/java/index.html?org/apache/spark/mllib/regression/LabeledPoint.html) -is now a wrapper of `(double, Vector)` instead of `(double, double[])`. Converting `double[]` to -`Vector` is straightforward: - -{% highlight java %} -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; - -double[] array = ... // a double array -Vector vector = Vectors.dense(array); // a dense vector -{% endhighlight %} - -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) provides factory methods to -create sparse vectors. +For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). -
- -
+## From 1.2 to 1.3 -We used to represent a labeled feature vector in a NumPy array, where the first entry corresponds to -the label and the rest are features. This representation is replaced by class -[`LabeledPoint`](api/python/pyspark.mllib.regression.LabeledPoint-class.html), which takes both -dense and sparse feature vectors. +In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. -{% highlight python %} -from pyspark.mllib.linalg import SparseVector -from pyspark.mllib.regression import LabeledPoint +* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. +* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. +* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: + * The constructor taking arguments was removed in favor of a builder patten using the default constructor plus parameter setter methods. + * Variable `model` is no longer public. +* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: + * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) + * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. +* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. +* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. + So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. -# Create a labeled point with a positive label and a dense feature vector. -pos = LabeledPoint(1.0, [1.0, 0.0, 3.0]) +## Previous Spark Versions -# Create a labeled point with a negative label and a sparse feature vector. -neg = LabeledPoint(0.0, SparseVector(3, [0, 2], [1.0, 3.0])) -{% endhighlight %} -
-
+Earlier migration guides are archived [on this page](mllib-migration-guides.html). diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md new file mode 100644 index 0000000000000..12fb29d426741 --- /dev/null +++ b/docs/mllib-isotonic-regression.md @@ -0,0 +1,155 @@ +--- +layout: global +title: Naive Bayes - MLlib +displayTitle: MLlib - Regression +--- + +## Isotonic regression +[Isotonic regression](http://en.wikipedia.org/wiki/Isotonic_regression) +belongs to the family of regression algorithms. Formally isotonic regression is a problem where +given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses +and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted +finding a function that minimises + +`\begin{equation} + f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 +\end{equation}` + +with respect to complete order subject to +`$x_1\le x_2\le ...\le x_n$` where `$w_i$` are positive weights. +The resulting function is called isotonic regression and it is unique. +It can be viewed as least squares problem under order restriction. +Essentially isotonic regression is a +[monotonic function](http://en.wikipedia.org/wiki/Monotonic_function) +best fitting the original data points. + +MLlib supports a +[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) +which uses an approach to +[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). +The training input is a RDD of tuples of three double values that represent +label, feature and weight in this order. Additionally IsotonicRegression algorithm has one +optional parameter called $isotonic$ defaulting to true. +This argument specifies if the isotonic regression is +isotonic (monotonically increasing) or antitonic (monotonically decreasing). + +Training returns an IsotonicRegressionModel that can be used to predict +labels for both known and unknown features. The result of isotonic regression +is treated as piecewise linear function. The rules for prediction therefore are: + +* If the prediction input exactly matches a training feature + then associated prediction is returned. In case there are multiple predictions with the same + feature then one of them is returned. Which one is undefined + (same as java.util.Arrays.binarySearch). +* If the prediction input is lower or higher than all training features + then prediction with lowest or highest feature is returned respectively. + In case there are multiple predictions with the same feature + then the lowest or highest is returned respectively. +* If the prediction input falls between two training features then prediction is treated + as piecewise linear function and interpolated value is calculated from the + predictions of the two closest features. In case there are multiple values + with the same feature then the same rules as in previous point are used. + +### Examples + +
+
+Data are read from a file where each line has a format label,feature +i.e. 4710.28,500.00. The data are split to training and testing set. +Model is created using the training set and a mean squared error is calculated from the predicted +labels and real labels in the test set. + +{% highlight scala %} +import org.apache.spark.mllib.regression.IsotonicRegression + +val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + +// Create label, feature, weight tuples from input data with weight set to default value 1.0. +val parsedData = data.map { line => + val parts = line.split(',').map(_.toDouble) + (parts(0), parts(1), 1.0) +} + +// Split data into training (60%) and test (40%) sets. +val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) +val training = splits(0) +val test = splits(1) + +// Create isotonic regression model from training data. +// Isotonic parameter defaults to true so it is only shown for demonstration +val model = new IsotonicRegression().setIsotonic(true).run(training) + +// Create tuples of predicted and real labels. +val predictionAndLabel = test.map { point => + val predictedLabel = model.predict(point._2) + (predictedLabel, point._1) +} + +// Calculate mean squared error between predicted and real labels. +val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() +println("Mean Squared Error = " + meanSquaredError) +{% endhighlight %} +
+ +
+Data are read from a file where each line has a format label,feature +i.e. 4710.28,500.00. The data are split to training and testing set. +Model is created using the training set and a mean squared error is calculated from the predicted +labels and real labels in the test set. + +{% highlight java %} +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.IsotonicRegressionModel; +import scala.Tuple2; +import scala.Tuple3; + +JavaRDD data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt"); + +// Create label, feature, weight tuples from input data with weight set to default value 1.0. +JavaRDD> parsedData = data.map( + new Function>() { + public Tuple3 call(String line) { + String[] parts = line.split(","); + return new Tuple3<>(new Double(parts[0]), new Double(parts[1]), 1.0); + } + } +); + +// Split data into training (60%) and test (40%) sets. +JavaRDD>[] splits = parsedData.randomSplit(new double[] {0.6, 0.4}, 11L); +JavaRDD> training = splits[0]; +JavaRDD> test = splits[1]; + +// Create isotonic regression model from training data. +// Isotonic parameter defaults to true so it is only shown for demonstration +IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); + +// Create tuples of predicted and real labels. +JavaPairRDD predictionAndLabel = test.mapToPair( + new PairFunction, Double, Double>() { + @Override public Tuple2 call(Tuple3 point) { + Double predictedLabel = model.predict(point._2()); + return new Tuple2(predictedLabel, point._1()); + } + } +); + +// Calculate mean squared error between predicted and real labels. +Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( + new Function, Object>() { + @Override public Object call(Tuple2 pl) { + return Math.pow(pl._1() - pl._2(), 2); + } + } +).rdd()).mean(); + +System.out.println("Mean Squared Error = " + meanSquaredError); +{% endhighlight %} +
+
\ No newline at end of file diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 44b7f67c57734..9270741d439d9 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -17,7 +17,7 @@ displayTitle: MLlib - Linear Methods \newcommand{\av}{\mathbf{\alpha}} \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} +\newcommand{\id}{\mathbf{I}} \newcommand{\ind}{\mathbf{1}} \newcommand{\0}{\mathbf{0}} \newcommand{\unit}{\mathbf{e}} @@ -114,18 +114,26 @@ especially when the number of training examples is small. Under the hood, linear methods use convex optimization methods to optimize the objective functions. MLlib uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. -## Binary classification - -[Binary classification](http://en.wikipedia.org/wiki/Binary_classification) -aims to divide items into two categories: positive and negative. MLlib -supports two linear methods for binary classification: linear Support Vector -Machines (SVMs) and logistic regression. For both methods, MLlib supports -L1 and L2 regularized variants. The training data set is represented by an RDD -of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the -mathematical formulation in this guide, a training label $y$ is denoted as -either $+1$ (positive) or $-1$ (negative), which is convenient for the -formulation. *However*, the negative label is represented by $0$ in MLlib -instead of $-1$, to be consistent with multiclass labeling. +## Classification + +[Classification](http://en.wikipedia.org/wiki/Statistical_classification) aims to divide items into +categories. +The most common classification type is +[binary classificaion](http://en.wikipedia.org/wiki/Binary_classification), where there are two +categories, usually named positive and negative. +If there are more than two categories, it is called +[multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification). +MLlib supports two linear methods for classification: linear Support Vector Machines (SVMs) +and logistic regression. +Linear SVMs supports only binary classification, while logistic regression supports both binary and +multiclass classification problems. +For both methods, MLlib supports L1 and L2 regularized variants. +The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, +where labels are class indices starting from zero: $0, 1, 2, \ldots$. +Note that, in the mathematical formulation in this guide, a binary label $y$ is denoted as either +$+1$ (positive) or $-1$ (negative), which is convenient for the formulation. +*However*, the negative label is represented by $0$ in MLlib instead of $-1$, to be consistent with +multiclass labeling. ### Linear Support Vector Machines (SVMs) @@ -144,41 +152,7 @@ denoted by $\x$, the model makes predictions based on the value of $\wv^T \x$. By the default, if $\wv^T \x \geq 0$ then the outcome is positive, and negative otherwise. -### Logistic regression - -[Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a -binary response. -It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss -function in the formulation given by the logistic loss: -`\[ -L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)). -\]` - -The logistic regression algorithm outputs a logistic regression model. Given a -new data point, denoted by $\x$, the model makes predictions by -applying the logistic function -`\[ -\mathrm{f}(z) = \frac{1}{1 + e^{-z}} -\]` -where $z = \wv^T \x$. -By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or -negative otherwise, though unlike linear SVMs, the raw output of the logistic regression -model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability -that $\x$ is positive). - -### Evaluation metrics - -MLlib supports common evaluation metrics for binary classification (not available in PySpark). -This -includes precision, recall, [F-measure](http://en.wikipedia.org/wiki/F1_score), -[receiver operating characteristic (ROC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic), -precision-recall curve, and -[area under the curves (AUC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve). -AUC is commonly used to compare the performance of various models while -precision/recall/F-measure can help determine the appropriate threshold to use -for prediction purposes. - -### Examples +**Examples**
@@ -190,7 +164,7 @@ error. {% highlight scala %} import org.apache.spark.SparkContext -import org.apache.spark.mllib.classification.SVMWithSGD +import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.Vectors @@ -211,7 +185,7 @@ val model = SVMWithSGD.train(training, numIterations) // Clear the default threshold. model.clearThreshold() -// Compute raw scores on the test set. +// Compute raw scores on the test set. val scoreAndLabels = test.map { point => val score = model.predict(point.features) (score, point.label) @@ -222,6 +196,10 @@ val metrics = new BinaryClassificationMetrics(scoreAndLabels) val auROC = metrics.areaUnderROC() println("Area under ROC = " + auROC) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = SVMModel.load(sc, "myModelPath") {% endhighlight %} The `SVMWithSGD.train()` method by default performs L2 regularization with the @@ -243,8 +221,6 @@ svmAlg.optimizer. val modelL1 = svmAlg.run(training) {% endhighlight %} -[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD) can be used in a similar fashion as `SVMWithSGD`. -
@@ -280,11 +256,11 @@ public class SVMClassifier { JavaRDD training = data.sample(false, 0.6, 11L); training.cache(); JavaRDD test = data.subtract(training); - + // Run training algorithm to build the model. int numIterations = 100; final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); - + // Clear the default threshold. model.clearThreshold(); @@ -297,13 +273,17 @@ public class SVMClassifier { } } ); - + // Get evaluation metrics. - BinaryClassificationMetrics metrics = + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels)); double auROC = metrics.areaUnderROC(); - + System.out.println("Area under ROC = " + auROC); + + // Save and load model + model.save(sc.sc(), "myModelPath"); + SVMModel sameModel = SVMModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -338,6 +318,8 @@ a dependency. The following example shows how to load a sample dataset, build Logistic Regression model, and make predictions with the resulting model to compute the training error. +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} from pyspark.mllib.classification import LogisticRegressionWithSGD from pyspark.mllib.regression import LabeledPoint @@ -362,7 +344,191 @@ print("Training Error = " + str(trainErr))
-## Linear least squares, Lasso, and ridge regression +### Logistic regression + +[Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a +binary response. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, +with the loss function in the formulation given by the logistic loss: +`\[ +L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)). +\]` + +For binary classification problems, the algorithm outputs a binary logistic regression model. +Given a new data point, denoted by $\x$, the model makes predictions by +applying the logistic function +`\[ +\mathrm{f}(z) = \frac{1}{1 + e^{-z}} +\]` +where $z = \wv^T \x$. +By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or +negative otherwise, though unlike linear SVMs, the raw output of the logistic regression +model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability +that $\x$ is positive). + +Binary logistic regression can be generalized into +[multinomial logistic regression](http://en.wikipedia.org/wiki/Multinomial_logistic_regression) to +train and predict multiclass classification problems. +For example, for $K$ possible outcomes, one of the outcomes can be chosen as a "pivot", and the +other $K - 1$ outcomes can be separately regressed against the pivot outcome. +In MLlib, the first class $0$ is chosen as the "pivot" class. +See Section 4.4 of +[The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for +references. +Here is an +[detailed mathematical derivation](http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297). + +For multiclass classification problems, the algorithm will outputs a multinomial logistic regression +model, which contains $K - 1$ binary logistic regression models regressed against the first class. +Given a new data points, $K - 1$ models will be run, and the class with largest probability will be +chosen as the predicted class. + +We implemented two algorithms to solve logistic regression: mini-batch gradient descent and L-BFGS. +We recommend L-BFGS over mini-batch gradient descent for faster convergence. + +**Examples** + +
+ +
+The following code illustrates how to load a sample multiclass dataset, split it into train and +test, and use +[LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) +to fit a logistic regression model. +Then the model is evaluated against the test dataset and saved to disk. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + +// Split data into training (60%) and test (40%). +val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) +val training = splits(0).cache() +val test = splits(1) + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training) + +// Compute raw scores on the test set. +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Get evaluation metrics. +val metrics = new MulticlassMetrics(predictionAndLabels) +val precision = metrics.precision +println("Precision = " + precision) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = LogisticRegressionModel.load(sc, "myModelPath") +{% endhighlight %} + +
+ +
+The following code illustrates how to load a sample multiclass dataset, split it into train and +test, and use +[LogisticRegressionWithLBFGS](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) +to fit a logistic regression model. +Then the model is evaluated against the test dataset and saved to disk. + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class MultinomialLogisticRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + double precision = metrics.precision(); + System.out.println("Precision = " + precision); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} +{% endhighlight %} +
+ +
+The following example shows how to load a sample dataset, build Logistic Regression model, +and make predictions with the resulting model to compute the training error. + +Note that the Python API does not yet support multiclass classification and model save/load but +will in the future. + +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.regression import LabeledPoint +from numpy import array + +# Load and parse the data +def parsePoint(line): + values = [float(x) for x in line.split(' ')] + return LabeledPoint(values[0], values[1:]) + +data = sc.textFile("data/mllib/sample_svm_data.txt") +parsedData = data.map(parsePoint) + +# Build the model +model = LogisticRegressionWithLBFGS.train(parsedData) + +# Evaluating the model on training data +labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) +trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) +print("Training Error = " + str(trainErr)) +{% endhighlight %} +
+
+ +# Regression + +### Linear least squares, Lasso, and ridge regression Linear least squares is the most common formulation for regression problems. @@ -380,7 +546,7 @@ regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) u regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_error). -### Examples +**Examples**
@@ -391,8 +557,9 @@ values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). {% highlight scala %} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -413,6 +580,10 @@ val valuesAndPreds = parsedData.map { point => } val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() println("training Mean Squared Error = " + MSE) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = LinearRegressionModel.load(sc, "myModelPath") {% endhighlight %} [`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) @@ -483,6 +654,10 @@ public class LinearRegression { } ).rdd()).mean(); System.out.println("training Mean Squared Error = " + MSE); + + // Save and load model + model.save(sc.sc(), "myModelPath"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -494,6 +669,8 @@ The example then uses LinearRegressionWithSGD to build a simple linear model to values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD from numpy import array @@ -523,7 +700,7 @@ section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. -## Streaming linear regression +###Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, updating the parameters of the model as new data arrives. MLlib currently supports @@ -531,7 +708,7 @@ streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. -### Examples +**Examples** The following example demonstrates how to load training and testing data from two different input streams of text files, parse the streams as labeled points, fit a linear regression model @@ -598,7 +775,7 @@ will get better!
-## Implementation (developer) +# Implementation (developer) Behind the scene, MLlib implements a simple distributed version of stochastic gradient descent (SGD), building on the underlying gradient descent primitive (as described in the MLlib - Old Migration Guides +description: MLlib migration guides from before Spark SPARK_VERSION_SHORT +--- + +The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). + +## From 1.1 to 1.2 + +The only API changes in MLlib v1.2 are in +[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +which continues to be an experimental API in MLlib 1.2: + +1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number +of classes. In MLlib v1.1, this argument was called `numClasses` in Python and +`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`. +This `numClasses` parameter is specified either via +[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) +or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) +static `trainClassifier` and `trainRegressor` methods. + +2. *(Breaking change)* The API for +[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed. +This should generally not affect user code, unless the user manually constructs decision trees +(instead of using the `trainClassifier` or `trainRegressor` methods). +The tree `Node` now includes more information, including the probability of the predicted label +(for classification). + +3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`. + +Examples in the Spark distribution and examples in the +[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly. + +## From 1.0 to 1.1 + +The only API changes in MLlib v1.1 are in +[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +which continues to be an experimental API in MLlib 1.1: + +1. *(Breaking change)* The meaning of tree depth has been changed by 1 in order to match +the implementations of trees in +[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree) +and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html). +In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes. +In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes. +This depth is specified by the `maxDepth` parameter in +[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) +or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) +static `trainClassifier` and `trainRegressor` methods. + +2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor` +methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +rather than using the old parameter class `Strategy`. These new training methods explicitly +separate classification and regression, and they replace specialized parameter types with +simple `String` types. + +Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the +[Decision Trees Guide](mllib-decision-tree.html#examples). + +## From 0.9 to 1.0 + +In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few +breaking changes. If your data is sparse, please store it in a sparse format instead of dense to +take advantage of sparsity in both storage and computation. Details are described below. + diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index d5b044d94fdd7..a83472f5be52e 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -37,7 +37,7 @@ smoothing parameter `lambda` as input, and output a can be used for evaluation and prediction. {% highlight scala %} -import org.apache.spark.mllib.classification.NaiveBayes +import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -55,6 +55,10 @@ val model = NaiveBayes.train(training, lambda = 1.0) val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = NaiveBayesModel.load(sc, "myModelPath") {% endhighlight %} @@ -93,34 +97,46 @@ double accuracy = predictionAndLabel.filter(new Function, return pl._1().equals(pl._2()); } }).count() / (double) test.count(); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +NaiveBayesModel sameModel = NaiveBayesModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
-[NaiveBayes](api/python/pyspark.mllib.classification.NaiveBayes-class.html) implements multinomial +[NaiveBayes](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayes) implements multinomial naive Bayes. It takes an RDD of -[LabeledPoint](api/python/pyspark.mllib.regression.LabeledPoint-class.html) and an optionally +[LabeledPoint](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint) and an optionally smoothing parameter `lambda` as input, and output a -[NaiveBayesModel](api/python/pyspark.mllib.classification.NaiveBayesModel-class.html), which can be +[NaiveBayesModel](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. - +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} -from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.classification import NaiveBayes +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint + +def parseLine(line): + parts = line.split(',') + label = float(parts[0]) + features = Vectors.dense([float(x) for x in parts[1].split(' ')]) + return LabeledPoint(label, features) + +data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine) -# an RDD of LabeledPoint -data = sc.parallelize([ - LabeledPoint(0.0, [0.0, 0.0]) - ... # more labeled points -]) +# Split data aproximately into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 0) # Train a naive Bayes model. -model = NaiveBayes.train(data, 1.0) +model = NaiveBayes.train(training, 1.0) -# Make prediction. -prediction = model.predict([0.0, 0.0]) +# Make prediction and test accuracy. +predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) +accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() {% endhighlight %}
diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 4d101afca2c97..6cabc1610a151 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -203,6 +203,10 @@ regularization, as well as L2 regularizer. recommended. * `maxNumIterations` is the maximal number of iterations that L-BFGS can be run. * `regParam` is the regularization parameter when using regularization. +* `convergenceTol` controls how much relative change is still allowed when L-BFGS +is considered to converge. This must be nonnegative. Lower values are less tolerant and +therefore generally cause more iterations to be run. This value looks at both average +improvement and the norm of gradient inside [Breeze LBFGS](https://github.com/scalanlp/breeze/blob/master/math/src/main/scala/breeze/optimize/LBFGS.scala). The `return` is a tuple containing two elements. The first element is a column matrix containing weights for every feature, and the second element is an array containing diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index ca8c29218f52d..887eae7f4f07b 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -81,8 +81,8 @@ System.out.println(summary.numNonzeros()); // number of nonzeros in each column
-[`colStats()`](api/python/pyspark.mllib.stat.Statistics-class.html#colStats) returns an instance of -[`MultivariateStatisticalSummary`](api/python/pyspark.mllib.stat.MultivariateStatisticalSummary-class.html), +[`colStats()`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics.colStats) returns an instance of +[`MultivariateStatisticalSummary`](api/python/pyspark.mllib.html#pyspark.mllib.stat.MultivariateStatisticalSummary), which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. @@ -169,7 +169,7 @@ Matrix correlMatrix = Statistics.corr(data.rdd(), "pearson");
-[`Statistics`](api/python/pyspark.mllib.stat.Statistics-class.html) provides methods to +[`Statistics`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) provides methods to calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix` respectively. @@ -258,7 +258,7 @@ JavaPairRDD exactSample = data.sampleByKeyExact(false, fractions); {% endhighlight %}
-[`sampleByKey()`](api/python/pyspark.rdd.RDD-class.html#sampleByKey) allows users to +[`sampleByKey()`](api/python/pyspark.html#pyspark.RDD.sampleByKey) allows users to sample approximately $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the desired fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the set of keys. @@ -476,7 +476,7 @@ JavaDoubleRDD v = u.map(
-[`RandomRDDs`](api/python/pyspark.mllib.random.RandomRDDs-class.html) provides factory +[`RandomRDDs`](api/python/pyspark.mllib.html#pyspark.mllib.random.RandomRDDs) provides factory methods to generate random double RDDs or vector RDDs. The following example generates a random double RDD, whose values follows the standard normal distribution `N(0, 1)`, and then map it to `N(1, 4)`. diff --git a/docs/monitoring.md b/docs/monitoring.md index f32cdef240d31..657bf67be2f0c 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -1,6 +1,7 @@ --- layout: global title: Monitoring and Instrumentation +description: Monitoring, metrics, and instrumentation guide for Spark SPARK_VERSION_SHORT --- There are several ways to monitor Spark applications: web UIs, metrics, and external instrumentation. @@ -149,6 +150,8 @@ follows: Note that in all of these UIs, the tables are sortable by clicking their headers, making it easy to identify slow tasks, data skew, etc. +Note that the history server only displays completed Spark jobs. One way to signal the completion of a Spark job is to stop the Spark Context explicitly (`sc.stop()`), or in Python using the `with SparkContext() as sc:` to handle the Spark Context setup and tear down, and still show the job history on the UI. + # Metrics Spark has a configurable metrics system based on the diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 6486614e71354..dcb80396569db 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1,6 +1,7 @@ --- layout: global title: Spark Programming Guide +description: Spark SPARK_VERSION_SHORT programming guide in Java, Scala and Python --- * This will become a table of contents (this text will be scraped). @@ -141,8 +142,8 @@ JavaSparkContext sc = new JavaSparkContext(conf);
-The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.context.SparkContext-class.html) object, which tells Spark -how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.conf.SparkConf-class.html) object +The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object, which tells Spark +how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.html#pyspark.SparkConf) object that contains information about your application. {% highlight python %} @@ -172,8 +173,11 @@ in-process. In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add JARs to the classpath -by passing a comma-separated list to the `--jars` argument. -For example, to run `bin/spark-shell` on exactly four cores, use: +by passing a comma-separated list to the `--jars` argument. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) +can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly +four cores, use: {% highlight bash %} $ ./bin/spark-shell --master local[4] @@ -185,6 +189,12 @@ Or, to also add `code.jar` to its classpath, use: $ ./bin/spark-shell --master local[4] --jars code.jar {% endhighlight %} +To include a dependency using maven coordinates: + +{% highlight bash %} +$ ./bin/spark-shell --master local[4] --packages "org.example:example:0.1" +{% endhighlight %} + For a complete list of options, run `spark-shell --help`. Behind the scenes, `spark-shell` invokes the more general [`spark-submit` script](submitting-applications.html). @@ -195,7 +205,11 @@ For a complete list of options, run `spark-shell --help`. Behind the scenes, In the PySpark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add Python .zip, .egg or .py files -to the runtime path by passing a comma-separated list to `--py-files`. +to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) +can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in +the requirements.txt of that package) must be manually installed using pip when necessary. For example, to run `bin/pyspark` on exactly four cores, use: {% highlight bash %} @@ -223,9 +237,13 @@ You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support: {% highlight bash %} -$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook" ./bin/pyspark {% endhighlight %} +After the IPython Notebook server is launched, you can create a new "Python 2" notebook from +the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of +your notebook before you start to try Spark from the IPython notebook. +
@@ -321,7 +339,7 @@ Apart from text files, Spark's Scala API also supports several other data format * For [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), use SparkContext's `sequenceFile[K, V]` method where `K` and `V` are the types of key and values in the file. These should be subclasses of Hadoop's [Writable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Writable.html) interface, like [IntWritable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/IntWritable.html) and [Text](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Text.html). In addition, Spark allows you to specify native types for a few common Writables; for example, `sequenceFile[Int, String]` will automatically read IntWritables and Texts. -* For other Hadoop InputFormats, you can use the `SparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `SparkContext.newHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). +* For other Hadoop InputFormats, you can use the `SparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `SparkContext.newAPIHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). * `RDD.saveAsObjectFile` and `SparkContext.objectFile` support saving an RDD in a simple format consisting of serialized Java objects. While this is not as efficient as specialized formats like Avro, it offers an easy way to save any RDD. @@ -353,7 +371,7 @@ Apart from text files, Spark's Java API also supports several other data formats * For [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), use SparkContext's `sequenceFile[K, V]` method where `K` and `V` are the types of key and values in the file. These should be subclasses of Hadoop's [Writable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Writable.html) interface, like [IntWritable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/IntWritable.html) and [Text](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Text.html). -* For other Hadoop InputFormats, you can use the `JavaSparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `JavaSparkContext.newHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). +* For other Hadoop InputFormats, you can use the `JavaSparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `JavaSparkContext.newAPIHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). * `JavaRDD.saveAsObjectFile` and `JavaSparkContext.objectFile` support saving an RDD in a simple format consisting of serialized Java objects. While this is not as efficient as specialized formats like Avro, it offers an easy way to save any RDD. @@ -835,7 +853,7 @@ The following table lists some of the common transformations supported by Spark. RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.rdd.RDD-class.html)) + [Python](api/python/pyspark.html#pyspark.RDD)) and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -856,7 +874,7 @@ for details. Similar to map, but each input item can be mapped to 0 or more output items (so func should return a Seq rather than a single item). - mapPartitions(func) + mapPartitions(func) Similar to map, but runs separately on each partition (block) of the RDD, so func must be of type Iterator<T> => Iterator<U> when running on an RDD of type T. @@ -883,7 +901,7 @@ for details. Return a new dataset that contains the distinct elements of the source dataset. - groupByKey([numTasks]) + groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will yield much better @@ -894,25 +912,25 @@ for details. - reduceByKey(func, [numTasks]) + reduceByKey(func, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function func, which must be of type (V,V) => V. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) + aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - sortByKey([ascending], [numTasks]) + sortByKey([ascending], [numTasks]) When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument. - join(otherDataset, [numTasks]) + join(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (V, W)) pairs with all pairs of elements for each key. Outer joins are supported through leftOuterJoin, rightOuterJoin, and fullOuterJoin. - cogroup(otherDataset, [numTasks]) + cogroup(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith. @@ -925,17 +943,17 @@ for details. process's stdin and lines output to its stdout are returned as an RDD of strings. - coalesce(numPartitions) + coalesce(numPartitions) Decrease the number of partitions in the RDD to numPartitions. Useful for running operations more efficiently after filtering down a large dataset. repartition(numPartitions) Reshuffle the data in the RDD randomly to create either more or fewer partitions and balance it across them. - This always shuffles all data over the network. + This always shuffles all data over the network. - repartitionAndSortWithinPartitions(partitioner) + repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -948,7 +966,7 @@ The following table lists some of the common actions supported by Spark. Refer t RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.rdd.RDD-class.html)) + [Python](api/python/pyspark.html#pyspark.RDD)) and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -974,7 +992,7 @@ for details. take(n) - Return an array with the first n elements of the dataset. Note that this is currently not executed in parallel. Instead, the driver program computes all the elements. + Return an array with the first n elements of the dataset. takeSample(withReplacement, num, [seed]) @@ -999,7 +1017,7 @@ for details. SparkContext.objectFile(). - countByKey() + countByKey() Only available on RDDs of type (K, V). Returns a hashmap of (K, Int) pairs with the count of each key. @@ -1008,6 +1026,67 @@ for details. +### Shuffle operations + +Certain operations within Spark trigger an event known as the shuffle. The shuffle is Spark's +mechanism for re-distributing data so that is grouped differently across partitions. This typically +involves copying data across executors and machines, making the shuffle a complex and +costly operation. + +#### Background + +To understand what happens during the shuffle we can consider the example of the +[`reduceByKey`](#ReduceByLink) operation. The `reduceByKey` operation generates a new RDD where all +values for a single key are combined into a tuple - the key and the result of executing a reduce +function against all values associated with that key. The challenge is that not all values for a +single key necessarily reside on the same partition, or even the same machine, but they must be +co-located to compute the result. + +In Spark, data is generally not distributed across partitions to be in the necessary place for a +specific operation. During computations, a single task will operate on a single partition - thus, to +organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - +this is called the **shuffle**. + +Although the set of elements in each partition of newly shuffled data will be deterministic, and so +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: + +* `mapPartitions` to sort each partition using, for example, `.sorted` +* `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning +* `sortBy` to make a globally ordered RDD + +Operations which can cause a shuffle include **repartition** operations like +[`repartition`](#RepartitionLink), and [`coalesce`](#CoalesceLink), **'ByKey** operations +(except for counting) like [`groupByKey`](#GroupByLink) and [`reduceByKey`](#ReduceByLink), and +**join** operations like [`cogroup`](#CogroupLink) and [`join`](#JoinLink). + +#### Performance Impact +The **Shuffle** is an expensive operation since it involves disk I/O, data serialization, and +network I/O. To organize data for the shuffle, Spark generates sets of tasks - *map* tasks to +organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from +MapReduce and does not directly relate to Spark's `map` and `reduce` operations. + +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks +read the relevant sorted blocks. + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables +to disk, incurring the additional overhead of disk I/O and increased garbage collection. + +Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files +are not cleaned up from Spark's temporary storage until Spark is stopped, which means that +long-running Spark jobs may consume available disk space. This is done so the shuffle doesn't need +to be re-computed if the lineage is re-computed. The temporary storage directory is specified by the +`spark.local.dir` configuration parameter when configuring the Spark context. + +Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). + ## RDD Persistence One of the most important capabilities in Spark is *persisting* (or *caching*) a dataset in memory @@ -1027,7 +1106,7 @@ replicate it across nodes, or store it off-heap in [Tachyon](http://tachyon-proj These levels are set by passing a `StorageLevel` object ([Scala](api/scala/index.html#org.apache.spark.storage.StorageLevel), [Java](api/java/index.html?org/apache/spark/storage/StorageLevel.html), -[Python](api/python/pyspark.storagelevel.StorageLevel-class.html)) +[Python](api/python/pyspark.html#pyspark.StorageLevel)) to `persist()`. The `cache()` method is a shorthand for using the default storage level, which is `StorageLevel.MEMORY_ONLY` (store deserialized objects in memory). The full set of storage levels is: @@ -1290,7 +1369,7 @@ scala> accum.value {% endhighlight %} While this code used the built-in support for accumulators of type Int, programmers can also -create their own types by subclassing [AccumulatorParam](api/python/pyspark.accumulators.AccumulatorParam-class.html). +create their own types by subclassing [AccumulatorParam](api/python/pyspark.html#pyspark.AccumulatorParam). The AccumulatorParam interface has two methods: `zero` for providing a "zero value" for your data type, and `addInPlace` for adding two values together. For example, supposing we had a `Vector` class representing mathematical vectors, we could write: diff --git a/docs/quick-start.md b/docs/quick-start.md index bf643bb70e153..81143da865cf0 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -1,6 +1,7 @@ --- layout: global title: Quick Start +description: Quick start tutorial for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 78358499fd01f..bf398164dd284 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -110,7 +110,7 @@ cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooK The driver also needs some configuration in `spark-env.sh` to interact properly with Mesos: 1. In `spark-env.sh` set some environment variables: - * `export MESOS_NATIVE_LIBRARY=`. This path is typically + * `export MESOS_NATIVE_JAVA_LIBRARY=`. This path is typically `/lib/libmesos.so` where the prefix is `/usr/local` by default. See Mesos installation instructions above. On Mac OS X, the library is called `libmesos.dylib` instead of `libmesos.so`. @@ -167,9 +167,6 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). -# Known issues -- When using the "fine-grained" mode, make sure that your executors always leave 32 MB free on the slaves. Otherwise it can happen that your Spark job does not proceed anymore. Currently, Apache Mesos only offers resources if there are at least 32 MB memory allocatable. But as Spark allocates memory only for the executor and cpu only for tasks, it can happen on high slave memory usage that no new tasks will be started anymore. More details can be found in [MESOS-1688](https://issues.apache.org/jira/browse/MESOS-1688). Alternatively use the "coarse-gained" mode, which is not affected by this issue. - # Running Alongside Hadoop You can run Spark and Mesos alongside your existing Hadoop cluster by just launching them as a @@ -197,7 +194,11 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.coarse false - Set the run mode for Spark on Mesos. For more information about the run mode, refer to #Mesos Run Mode section above. + If set to "true", runs over Mesos clusters in + "coarse-grained" sharing mode, + where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per + Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use + for the whole duration of the Spark job. @@ -211,19 +212,23 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.executor.home - SPARK_HOME + driver side SPARK_HOME - The location where the mesos executor will look for Spark binaries to execute, and uses the SPARK_HOME setting on default. - This variable is only used when no spark.executor.uri is provided, and assumes Spark is installed on the specified location - on each slave. + Set the directory in which Spark is installed on the executors in Mesos. By default, the + executors will simply use the driver's Spark home directory, which may not be visible to + them. Note that this is only relevant if a Spark binary package is not specified through + spark.executor.uri. spark.mesos.executor.memoryOverhead - 384 + executor memory * 0.07, with minimum of 384 - The amount of memory that Mesos executor will request for the task to account for the overhead of running the executor itself. - The final total amount of memory allocated is the maximum value between executor memory plus memoryOverhead, and overhead fraction (1.07) plus the executor memory. + This value is an additive for spark.executor.memory, specified in MiB, + which is used to calculate the total Mesos task memory. A value of 384 + implies a 384MiB overhead. Additionally, there is a hard-coded 7% minimum + overhead. The final overhead will be the larger of either + `spark.mesos.executor.memoryOverhead` or 7% of `spark.executor.memory`. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 68ab127bcf087..8c5bf9672461f 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -104,6 +104,13 @@ Most of the configs are the same for Spark on YARN as for other deployment modes Comma-separated list of files to be placed in the working directory of each executor. + + spark.executor.instances + 2 + + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. + + spark.yarn.executor.memoryOverhead executorMemory * 0.07, with minimum of 384 @@ -267,6 +274,6 @@ If you need a reference to the proper location to put log files in the YARN so t # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. -- The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. +- In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. - The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/security.md b/docs/security.md index 6e0a54fbc4ad7..c034ba12ff1fc 100644 --- a/docs/security.md +++ b/docs/security.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark Security +displayTitle: Spark Security +title: Security --- Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 5c6084fb46255..74d8653a8b845 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -222,8 +222,7 @@ SPARK_WORKER_OPTS supports the following system properties: false Enable periodic cleanup of worker / application directories. Note that this only affects standalone - mode, as YARN works differently. Applications directories are cleaned up regardless of whether - the application is still running. + mode, as YARN works differently. Only the directories of stopped applications are cleaned up. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index be8c5c2c1522e..332618edf0c55 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark SQL Programming Guide +displayTitle: Spark SQL and DataFrame Guide +title: Spark SQL and DataFrames --- * This will become a table of contents (this text will be scraped). @@ -8,168 +9,348 @@ title: Spark SQL Programming Guide # Overview +Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. + + +# DataFrames + +A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. + +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), and [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame). + +All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell` or the `pyspark` shell. + + +## Starting Point: `SQLContext` +
-Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using -Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of -[Row](api/scala/index.html#org.apache.spark.sql.package@Row:org.apache.spark.sql.catalyst.expressions.Row.type) objects, along with -a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) -file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +The entry point into all functionality in Spark SQL is the +[`SQLContext`](api/scala/index.html#org.apache.spark.sql.`SQLContext`) class, or one of its +descendants. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight scala %} +val sc: SparkContext // An existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`. +// this is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ +{% endhighlight %}
-
-Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using -Spark. At the core of this component is a new type of RDD, -[JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed of -[Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects, along with -a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table -in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) -file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +
+ +The entry point into all functionality in Spark SQL is the +[`SQLContext`](api/java/index.html#org.apache.spark.sql.SQLContext) class, or one of its +descendants. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight java %} +JavaSparkContext sc = ...; // An existing JavaSparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); +{% endhighlight %} +
-Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using -Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/python/pyspark.sql.SchemaRDD-class.html). SchemaRDDs are composed of -[Row](api/python/pyspark.sql.Row-class.html) objects, along with -a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) -file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +The entry point into all relational functionality in Spark is the +[`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one +of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight python %} +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) +{% endhighlight %} -All of the examples on this page use sample data included in the Spark distribution and can be run in the `pyspark` shell.
-**Spark SQL is currently an alpha component. While we will minimize API changes, some APIs may change in future releases.** +In addition to the basic `SQLContext`, you can also create a `HiveContext`, which provides a +superset of the functionality provided by the basic `SQLContext`. Additional features include +the ability to write queries using the more complete HiveQL parser, access to Hive UDFs, and the +ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an +existing Hive setup, and all of the data sources available to a `SQLContext` are still available. +`HiveContext` is only packaged separately to avoid including all of Hive's dependencies in the default +Spark build. If these dependencies are not a problem for your application then using `HiveContext` +is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up +to feature parity with a `HiveContext`. + +The specific variant of SQL that is used to parse queries can also be selected using the +`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on +a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect +available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the +default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, +this is recommended for most use cases. + + +## Creating DataFrames -*************************************************************************************************** +With a `SQLContext`, applications can create `DataFrame`s from an existing `RDD`, from a Hive table, or from data sources. -# Getting Started +As an example, the following creates a `DataFrame` based on the content of a JSON file:
- -The entry point into all relational functionality in Spark is the -[SQLContext](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic SQLContext, all you need is a SparkContext. - {% highlight scala %} val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD -{% endhighlight %} +val df = sqlContext.jsonFile("examples/src/main/resources/people.json") -In addition to the basic SQLContext, you can also create a HiveContext, which provides a -superset of the functionality provided by the basic SQLContext. Additional features include -the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the -ability to read data from Hive tables. To use a HiveContext, you do not need to have an -existing Hive setup, and all of the data sources available to a SQLContext are still available. -HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using HiveContext -is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to -feature parity with a HiveContext. +// Displays the content of the DataFrame to stdout +df.show() +{% endhighlight %}
- -The entry point into all relational functionality in Spark is the -[JavaSQLContext](api/scala/index.html#org.apache.spark.sql.api.java.JavaSQLContext) class, or one -of its descendants. To create a basic JavaSQLContext, all you need is a JavaSparkContext. - {% highlight java %} JavaSparkContext sc = ...; // An existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); -{% endhighlight %} +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); -In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict -super set of the functionality provided by the basic SQLContext. Additional features include -the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the -ability to read data from Hive tables. To use a HiveContext, you do not need to have an -existing Hive setup, and all of the data sources available to a SQLContext are still available. -HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using HiveContext -is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to -feature parity with a HiveContext. +DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); + +// Displays the content of the DataFrame to stdout +df.show(); +{% endhighlight %}
+{% highlight python %} +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) -The entry point into all relational functionality in Spark is the -[SQLContext](api/python/pyspark.sql.SQLContext-class.html) class, or one -of its decedents. To create a basic SQLContext, all you need is a SparkContext. +df = sqlContext.jsonFile("examples/src/main/resources/people.json") +# Displays the content of the DataFrame to stdout +df.show() +{% endhighlight %} + +
+
+ + +## DataFrame Operations + +DataFrames provide a domain-specific language for structured data manipulation in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), and [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame). + +Here we include some basic examples of structured data processing using DataFrames: + + +
+
+{% highlight scala %} +val sc: SparkContext // An existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) + +// Create the DataFrame +val df = sqlContext.jsonFile("examples/src/main/resources/people.json") + +// Show the content of the DataFrame +df.show() +// age name +// null Michael +// 30 Andy +// 19 Justin + +// Print the schema in a tree format +df.printSchema() +// root +// |-- age: long (nullable = true) +// |-- name: string (nullable = true) + +// Select only the "name" column +df.select("name").show() +// name +// Michael +// Andy +// Justin + +// Select everybody, but increment the age by 1 +df.select(df("name"), df("age") + 1).show() +// name (age + 1) +// Michael null +// Andy 31 +// Justin 20 + +// Select people older than 21 +df.filter(df("age") > 21).show() +// age name +// 30 Andy + +// Count people by age +df.groupBy("age").count().show() +// age count +// null 1 +// 19 1 +// 30 1 +{% endhighlight %} + +
+ +
+{% highlight java %} +val sc: JavaSparkContext // An existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) + +// Create the DataFrame +DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); + +// Show the content of the DataFrame +df.show(); +// age name +// null Michael +// 30 Andy +// 19 Justin + +// Print the schema in a tree format +df.printSchema(); +// root +// |-- age: long (nullable = true) +// |-- name: string (nullable = true) + +// Select only the "name" column +df.select("name").show(); +// name +// Michael +// Andy +// Justin + +// Select everybody, but increment the age by 1 +df.select(df.col("name"), df.col("age").plus(1)).show(); +// name (age + 1) +// Michael null +// Andy 31 +// Justin 20 + +// Select people older than 21 +df.filter(df.col("age").gt(21)).show(); +// age name +// 30 Andy + +// Count people by age +df.groupBy("age").count().show(); +// age count +// null 1 +// 19 1 +// 30 1 +{% endhighlight %} + +
+ +
{% highlight python %} from pyspark.sql import SQLContext sqlContext = SQLContext(sc) -{% endhighlight %} -In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict -super set of the functionality provided by the basic SQLContext. Additional features include -the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the -ability to read data from Hive tables. To use a HiveContext, you do not need to have an -existing Hive setup, and all of the data sources available to a SQLContext are still available. -HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using HiveContext -is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to -feature parity with a HiveContext. +# Create the DataFrame +df = sqlContext.jsonFile("examples/src/main/resources/people.json") + +# Show the content of the DataFrame +df.show() +## age name +## null Michael +## 30 Andy +## 19 Justin + +# Print the schema in a tree format +df.printSchema() +## root +## |-- age: long (nullable = true) +## |-- name: string (nullable = true) + +# Select only the "name" column +df.select("name").show() +## name +## Michael +## Andy +## Justin + +# Select everybody, but increment the age by 1 +df.select(df.name, df.age + 1).show() +## name (age + 1) +## Michael null +## Andy 31 +## Justin 20 + +# Select people older than 21 +df.filter(df.age > 21).show() +## age name +## 30 Andy + +# Count people by age +df.groupBy("age").count().show() +## age count +## null 1 +## 19 1 +## 30 1 +{% endhighlight %} + +
+ +## Running SQL Queries Programmatically + +The `sql` function on a `SQLContext` enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. + +
+
+{% highlight scala %} +val sqlContext = ... // An existing SQLContext +val df = sqlContext.sql("SELECT * FROM table") +{% endhighlight %}
-The specific variant of SQL that is used to parse queries can also be selected using the -`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on -a SQLContext or by using a `SET key=value` command in SQL. For a SQLContext, the only dialect -available is "sql" which uses a simple SQL parser provided by Spark SQL. In a HiveContext, the -default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, - this is recommended for most use cases. +
+{% highlight java %} +val sqlContext = ... // An existing SQLContext +val df = sqlContext.sql("SELECT * FROM table") +{% endhighlight %} +
-# Data Sources +
+{% highlight python %} +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) +df = sqlContext.sql("SELECT * FROM table") +{% endhighlight %} +
+
-Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. -A SchemaRDD can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section -describes the various methods for loading data into a SchemaRDD. -## RDDs +## Interoperating with RDDs -Spark SQL supports two different methods for converting existing RDDs into SchemaRDDs. The first +Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first method uses reflection to infer the schema of an RDD that contains specific types of objects. This reflection based approach leads to more concise code and works well when you already know the schema while writing your Spark application. -The second method for creating SchemaRDDs is through a programmatic interface that allows you to +The second method for creating DataFrames is through a programmatic interface that allows you to construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows -you to construct SchemaRDDs when the columns and their types are not known until runtime. +you to construct DataFrames when the columns and their types are not known until runtime. ### Inferring the Schema Using Reflection
-The Scala interaface for Spark SQL supports automatically converting an RDD containing case classes -to a SchemaRDD. The case class +The Scala interface for Spark SQL supports automatically converting an RDD containing case classes +to a DataFrame. The case class defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex -types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be +types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} // sc is an existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// this is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ // Define the schema using a case class. // Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, @@ -177,13 +358,13 @@ import sqlContext.createSchemaRDD case class Person(name: String, age: Int) // Create an RDD of Person objects and register it as a table. -val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)) +val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)).toDF() people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -193,7 +374,7 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table. +into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. @@ -224,12 +405,12 @@ public static class Person implements Serializable { {% endhighlight %} -A schema can be applied to an existing RDD by calling `applySchema` and providing the Class object +A schema can be applied to an existing RDD by calling `createDataFrame` and providing the Class object for the JavaBean. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").map( @@ -246,13 +427,13 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m }); // Apply a schema to an RDD of JavaBeans and register it as a table. -JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class); +DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List teenagerNames = teenagers.map(new Function() { public String call(Row row) { @@ -266,7 +447,7 @@ List teenagerNames = teenagers.map(new Function() {
-Spark SQL can convert an RDD of Row objects to a SchemaRDD, inferring the datatypes. Rows are constructed by passing a list of +Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, and the types are inferred by looking at the first row. Since we currently only look at the first row, it is important that there is no missing data in the first row of the RDD. In future versions we @@ -283,11 +464,11 @@ lines = sc.textFile("examples/src/main/resources/people.txt") parts = lines.map(lambda l: l.split(",")) people = parts.map(lambda p: Row(name=p[0], age=int(p[1]))) -# Infer the schema, and register the SchemaRDD as a table. +# Infer the schema, and register the DataFrame as a table. schemaPeople = sqlContext.inferSchema(people) schemaPeople.registerTempTable("people") -# SQL can be run over SchemaRDDs that have been registered as a table. +# SQL can be run over DataFrames that have been registered as a table. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") # The results of SQL queries are RDDs and support all the normal RDD operations. @@ -309,12 +490,12 @@ for teenName in teenNames.collect(): When case classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. -3. Apply the schema to the RDD of `Row`s via `applySchema` method provided +3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided by `SQLContext`. For example: @@ -328,8 +509,11 @@ val people = sc.textFile("examples/src/main/resources/people.txt") // The schema is encoded in a string val schemaString = "name age" -// Import Spark SQL data types and Row. -import org.apache.spark.sql._ +// Import Row. +import org.apache.spark.sql.Row; + +// Import Spark SQL data types +import org.apache.spark.sql.types.{StructType,StructField,StringType}; // Generate the schema based on the string of schema val schema = @@ -340,15 +524,15 @@ val schema = val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim)) // Apply the schema to the RDD. -val peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema) +val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema) -// Register the SchemaRDD as a table. -peopleSchemaRDD.registerTempTable("people") +// Register the DataFrames as a table. +peopleDataFrame.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val results = sqlContext.sql("SELECT name FROM people") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. results.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -361,26 +545,26 @@ results.map(t => "Name: " + t(0)).collect().foreach(println) When JavaBean classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. -3. Apply the schema to the RDD of `Row`s via `applySchema` method provided -by `JavaSQLContext`. +3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided +by `SQLContext`. For example: {% highlight java %} // Import factory methods provided by DataType. -import org.apache.spark.sql.api.java.DataType +import org.apache.spark.sql.types.DataType; // Import StructType and StructField -import org.apache.spark.sql.api.java.StructType -import org.apache.spark.sql.api.java.StructField +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.StructField; // Import Row. -import org.apache.spark.sql.api.java.Row +import org.apache.spark.sql.Row; // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt"); @@ -405,15 +589,15 @@ JavaRDD rowRDD = people.map( }); // Apply the schema to the RDD. -JavaSchemaRDD peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema); +DataFrame peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema); -// Register the SchemaRDD as a table. -peopleSchemaRDD.registerTempTable("people"); +// Register the DataFrame as a table. +peopleDataFrame.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD results = sqlContext.sql("SELECT name FROM people"); +DataFrame results = sqlContext.sql("SELECT name FROM people"); -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List names = results.map(new Function() { public String call(Row row) { @@ -430,17 +614,18 @@ List names = results.map(new Function() { When a dictionary of kwargs cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of tuples or lists from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of tuples or lists in the RDD created in the step 1. -3. Apply the schema to the RDD via `applySchema` method provided by `SQLContext`. +3. Apply the schema to the RDD via `createDataFrame` method provided by `SQLContext`. For example: {% highlight python %} # Import SQLContext and data types -from pyspark.sql import * +from pyspark.sql import SQLContext +from pyspark.sql.types import * # sc is an existing SparkContext. sqlContext = SQLContext(sc) @@ -457,12 +642,12 @@ fields = [StructField(field_name, StringType(), True) for field_name in schemaSt schema = StructType(fields) # Apply the schema to the RDD. -schemaPeople = sqlContext.applySchema(people, schema) +schemaPeople = sqlContext.createDataFrame(people, schema) -# Register the SchemaRDD as a table. +# Register the DataFrame as a table. schemaPeople.registerTempTable("people") -# SQL can be run over SchemaRDDs that have been registered as a table. +# SQL can be run over DataFrames that have been registered as a table. results = sqlContext.sql("SELECT name FROM people") # The results of SQL queries are RDDs and support all the normal RDD operations. @@ -471,11 +656,157 @@ for name in names.collect(): print name {% endhighlight %} +
+ +
+ + +# Data Sources + +Spark SQL supports operating on a variety of data sources through the `DataFrame` interface. +A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table. +Registering a DataFrame as a table allows you to run SQL queries over its data. This section +describes the general methods for loading and saving data using the Spark Data Sources and then +goes into specific options that are available for the built-in data sources. + +## Generic Load/Save Functions + +In the simplest form, the default data source (`parquet` unless otherwise configured by +`spark.sql.sources.default`) will be used for all operations. + +
+
+ +{% highlight scala %} +val df = sqlContext.load("people.parquet") +df.select("name", "age").save("namesAndAges.parquet") +{% endhighlight %} + +
+ +
+ +{% highlight java %} + +DataFrame df = sqlContext.load("people.parquet"); +df.select("name", "age").save("namesAndAges.parquet"); + +{% endhighlight %} + +
+ +
+ +{% highlight python %} + +df = sqlContext.load("people.parquet") +df.select("name", "age").save("namesAndAges.parquet") + +{% endhighlight %} + +
+
+ +### Manually Specifying Options + +You can also manually specify the data source that will be used along with any extra options +that you would like to pass to the data source. Data sources are specified by their fully qualified +name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use the shorted +name (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +using this syntax. + +
+
+ +{% highlight scala %} +val df = sqlContext.load("people.json", "json") +df.select("name", "age").save("namesAndAges.parquet", "parquet") +{% endhighlight %}
+
+ +{% highlight java %} + +DataFrame df = sqlContext.load("people.json", "json"); +df.select("name", "age").save("namesAndAges.parquet", "parquet"); + +{% endhighlight %} +
+
+ +{% highlight python %} + +df = sqlContext.load("people.json", "json") +df.select("name", "age").save("namesAndAges.parquet", "parquet") + +{% endhighlight %} + +
+
+ +### Save Modes + +Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if +present. It is important to realize that these save modes do not utilize any locking and are not +atomic. Thus, it is not safe to have multiple writers attempting to write to the same location. +Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +new data. + + + + + + + + + + + + + + + + + + + + + + + +
Scala/JavaPythonMeaning
SaveMode.ErrorIfExists (default)"error" (default) + When saving a DataFrame to a data source, if data already exists, + an exception is expected to be thrown. +
SaveMode.Append"append" + When saving a DataFrame to a data source, if data/table already exists, + contents of the DataFrame are expected to be appended to existing data. +
SaveMode.Overwrite"overwrite" + Overwrite mode means that when saving a DataFrame to a data source, + if data/table already exists, existing data is expected to be overwritten by the contents of + the DataFrame. +
SaveMode.Ignore"ignore" + Ignore mode means that when saving a DataFrame to a data source, if data already exists, + the save operation is expected to not save the contents of the DataFrame and to not + change the existing data. This is similar to a `CREATE TABLE IF NOT EXISTS` in SQL. +
+ +### Saving to Persistent Tables + +When working with a `HiveContext`, `DataFrames` can also be saved as persistent tables using the +`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the +contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables +will still exist even after your Spark program has restarted, as long as you maintain your connection +to the same metastore. A DataFrame for a persistent table can be created by calling the `table` +method on a `SQLContext` with the name of the table. + +By default `saveAsTable` will create a "managed table", meaning that the location of the data will +be controlled by the metastore. Managed tables will also have their data deleted automatically +when a table is dropped. + ## Parquet Files [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. @@ -492,16 +823,16 @@ Using the data from the above example: {% highlight scala %} // sqlContext from the previous example is used in this example. -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// This is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. -// The RDD is implicitly converted to a SchemaRDD by createSchemaRDD, allowing it to be stored using Parquet. +// The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. people.saveAsParquetFile("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a Parquet file is also a SchemaRDD. +// The result of loading a Parquet file is also a DataFrame. val parquetFile = sqlContext.parquetFile("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. @@ -517,18 +848,18 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% highlight java %} // sqlContext from the previous example is used in this example. -JavaSchemaRDD schemaPeople = ... // The JavaSchemaRDD from the previous example. +DataFrame schemaPeople = ... // The DataFrame from the previous example. -// JavaSchemaRDDs can be saved as Parquet files, maintaining the schema information. +// DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a parquet file is also a JavaSchemaRDD. -JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet"); +// The result of loading a parquet file is also a DataFrame. +DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); +DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); @@ -543,13 +874,13 @@ List teenagerNames = teenagers.map(new Function() { {% highlight python %} # sqlContext from the previous example is used in this example. -schemaPeople # The SchemaRDD from the previous example. +schemaPeople # The DataFrame from the previous example. -# SchemaRDDs can be saved as Parquet files, maintaining the schema information. +# DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -# The result of loading a parquet file is also a SchemaRDD. +# The result of loading a parquet file is also a DataFrame. parquetFile = sqlContext.parquetFile("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. @@ -562,11 +893,150 @@ for teenName in teenNames.collect():
+
+ +{% highlight sql %} + +CREATE TEMPORARY TABLE parquetTable +USING org.apache.spark.sql.parquet +OPTIONS ( + path "examples/src/main/resources/people.parquet" +) + +SELECT * FROM parquetTable + +{% endhighlight %} + +
+ +
+ +### Partition discovery + +Table partitioning is a common optimization approach used in systems like Hive. In a partitioned +table, data are usually stored in different directories, with partitioning column values encoded in +the path of each partition directory. The Parquet data source is now able to discover and infer +partitioning information automatically. For exmaple, we can store all our previously used +population data into a partitioned table using the following directory structure, with two extra +columns, `gender` and `country` as partitioning columns: + +{% highlight text %} + +path +└── to + └── table + ├── gender=male + │   ├── ... + │   │ + │   ├── country=US + │   │   └── data.parquet + │   ├── country=CN + │   │   └── data.parquet + │   └── ... + └── gender=female +    ├── ... +    │ +    ├── country=US +    │   └── data.parquet +    ├── country=CN +    │   └── data.parquet +    └── ... + +{% endhighlight %} + +By passing `path/to/table` to either `SQLContext.parquetFile` or `SQLContext.load`, Spark SQL will +automatically extract the partitioning information from the paths. Now the schema of the returned +DataFrame becomes: + +{% highlight text %} + +root +|-- name: string (nullable = true) +|-- age: long (nullable = true) +|-- gender: string (nullable = true) +|-- country: string (nullable = true) + +{% endhighlight %} + +Notice that the data types of the partitioning columns are automatically inferred. Currently, +numeric data types and string type are supported. + +### Schema merging + +Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with +a simple schema, and gradually add more columns to the schema as needed. In this way, users may end +up with multiple Parquet files with different but mutually compatible schemas. The Parquet data +source is now able to automatically detect this case and merge schemas of all these files. + +
+ +
+ +{% highlight scala %} +// sqlContext from the previous example is used in this example. +// This is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ + +// Create a simple DataFrame, stored into a partition directory +val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") +df1.saveAsParquetFile("data/test_table/key=1") + +// Create another DataFrame in a new partition directory, +// adding a new column and dropping an existing column +val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") +df2.saveAsParquetFile("data/test_table/key=2") + +// Read the partitioned table +val df3 = sqlContext.parquetFile("data/test_table") +df3.printSchema() + +// The final schema consists of all 3 columns in the Parquet files together +// with the partiioning column appeared in the partition directory paths. +// root +// |-- single: int (nullable = true) +// |-- double: int (nullable = true) +// |-- triple: int (nullable = true) +// |-- key : int (nullable = true) +{% endhighlight %} + +
+ +
+ +{% highlight python %} +# sqlContext from the previous example is used in this example. + +# Create a simple DataFrame, stored into a partition directory +df1 = sqlContext.createDataFrame(sc.parallelize(range(1, 6))\ + .map(lambda i: Row(single=i, double=i * 2))) +df1.save("data/test_table/key=1", "parquet") + +# Create another DataFrame in a new partition directory, +# adding a new column and dropping an existing column +df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) + .map(lambda i: Row(single=i, triple=i * 3))) +df2.save("data/test_table/key=2", "parquet") + +# Read the partitioned table +df3 = sqlContext.parquetFile("data/test_table") +df3.printSchema() + +# The final schema consists of all 3 columns in the Parquet files together +# with the partiioning column appeared in the partition directory paths. +# root +# |-- single: int (nullable = true) +# |-- double: int (nullable = true) +# |-- triple: int (nullable = true) +# |-- key : int (nullable = true) +{% endhighlight %} + +
+
### Configuration -Configuration of Parquet can be done using the `setConf` method on SQLContext or by running +Configuration of Parquet can be done using the `setConf` method on `SQLContext` or by running `SET key=value` commands using SQL. @@ -580,6 +1050,15 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. + + + + + @@ -619,8 +1098,8 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. -This conversion can be done using one of two methods in a SQLContext: +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using one of two methods in a `SQLContext`: * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. * `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. @@ -636,7 +1115,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. val path = "examples/src/main/resources/people.json" -// Create a SchemaRDD from the file(s) pointed to by path +// Create a DataFrame from the file(s) pointed to by path val people = sqlContext.jsonFile(path) // The inferred schema can be visualized using the printSchema() method. @@ -645,13 +1124,13 @@ people.printSchema() // |-- age: integer (nullable = true) // |-- name: string (nullable = true) -// Register this SchemaRDD as a table. +// Register this DataFrame as a table. people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// Alternatively, a SchemaRDD can be created for a JSON dataset represented by +// Alternatively, a DataFrame can be created for a JSON dataset represented by // an RDD[String] storing one JSON object per string. val anotherPeopleRDD = sc.parallelize( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) @@ -661,8 +1140,8 @@ val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a JavaSchemaRDD. -This conversion can be done using one of two methods in a JavaSQLContext : +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using one of two methods in a `SQLContext` : * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. * `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. @@ -673,13 +1152,13 @@ a regular multi-line JSON file will most often fail. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; -// Create a JavaSchemaRDD from the file(s) pointed to by path -JavaSchemaRDD people = sqlContext.jsonFile(path); +// Create a DataFrame from the file(s) pointed to by path +DataFrame people = sqlContext.jsonFile(path); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); @@ -687,24 +1166,24 @@ people.printSchema(); // |-- age: integer (nullable = true) // |-- name: string (nullable = true) -// Register this JavaSchemaRDD as a table. +// Register this DataFrame as a table. people.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlContext. -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); +DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); -// Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by +// Alternatively, a DataFrame can be created for a JSON dataset represented by // an RDD[String] storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -JavaSchemaRDD anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); {% endhighlight %}
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. -This conversion can be done using one of two methods in a SQLContext: +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using one of two methods in a `SQLContext`: * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. * `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. @@ -721,7 +1200,7 @@ sqlContext = SQLContext(sc) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. path = "examples/src/main/resources/people.json" -# Create a SchemaRDD from the file(s) pointed to by path +# Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # The inferred schema can be visualized using the printSchema() method. @@ -730,13 +1209,13 @@ people.printSchema() # |-- age: integer (nullable = true) # |-- name: string (nullable = true) -# Register this SchemaRDD as a table. +# Register this DataFrame as a table. people.registerTempTable("people") -# SQL statements can be run by using the sql methods provided by sqlContext. +# SQL statements can be run by using the sql methods provided by `sqlContext`. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -# Alternatively, a SchemaRDD can be created for a JSON dataset represented by +# Alternatively, a DataFrame can be created for a JSON dataset represented by # an RDD[String] storing one JSON object per string. anotherPeopleRDD = sc.parallelize([ '{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}']) @@ -744,6 +1223,22 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) {% endhighlight %}
+
+ +{% highlight sql %} + +CREATE TEMPORARY TABLE jsonTable +USING org.apache.spark.sql.json +OPTIONS ( + path "examples/src/main/resources/people.json" +) + +SELECT * FROM jsonTable + +{% endhighlight %} + +
+
## Hive Tables @@ -763,7 +1258,7 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can still create a HiveContext. When not configured by the +not have an existing Hive deployment can still create a `HiveContext`. When not configured by the hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current directory. @@ -782,14 +1277,14 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println)
-When working with Hive one must construct a `JavaHiveContext`, which inherits from `JavaSQLContext`, and +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allows queries to be +the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be expressed in HiveQL. {% highlight java %} // sc is an existing JavaSparkContext. -JavaHiveContext sqlContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); +HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc); sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); @@ -824,6 +1319,121 @@ results = sqlContext.sql("FROM src SELECT key, value").collect()
+## JDBC To Other Databases + +Spark SQL also includes a data source that can read data from other databases using JDBC. This +functionality should be preferred over using [JdbcRDD](api/scala/index.html#org.apache.spark.rdd.JdbcRDD). +This is because the results are returned +as a DataFrame and they can easily be processed in Spark SQL or joined with other data sources. +The JDBC data source is also easier to use from Java or Python as it does not require the user to +provide a ClassTag. +(Note that this is different than the Spark SQL JDBC server, which allows other applications to +run queries using Spark SQL). + +To get started you will need to include the JDBC driver for you particular database on the +spark classpath. For example, to connect to postgres from the Spark Shell you would run the +following command: + +{% highlight bash %} +SPARK_CLASSPATH=postgresql-9.3-1102-jdbc41.jar bin/spark-shell +{% endhighlight %} + +Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using +the Data Sources API. The following options are supported: + +
spark.sql.parquet.int96AsTimestamptrue + Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. Spark would also + store Timestamp as INT96 because we need to avoid precision lost of the nanoseconds field. This + flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. +
spark.sql.parquet.cacheMetadata true
+ + + + + + + + + + + + + + + + + + +
Property NameMeaning
url + The JDBC URL to connect to. +
dbtable + The JDBC table that should be read. Note that anything that is valid in a `FROM` clause of + a SQL query can be used. For example, instead of a full table you could also use a + subquery in parentheses. +
driver + The class name of the JDBC driver needed to connect to this URL. This class with be loaded + on the master and workers before running an JDBC commands to allow the driver to + register itself with the JDBC subsystem. +
partitionColumn, lowerBound, upperBound, numPartitions + These options must all be specified if any of them is specified. They describe how to + partition the table when reading in parallel from multiple workers. + partitionColumn must be a numeric column from the table in question. +
+ +
+ +
+ +{% highlight scala %} +val jdbcDF = sqlContext.load("jdbc", Map( + "url" -> "jdbc:postgresql:dbserver", + "dbtable" -> "schema.tablename")) +{% endhighlight %} + +
+ +
+ +{% highlight java %} + +Map options = new HashMap(); +options.put("url", "jdbc:postgresql:dbserver"); +options.put("dbtable", "schema.tablename"); + +DataFrame jdbcDF = sqlContext.load("jdbc", options) +{% endhighlight %} + + +
+ +
+ +{% highlight python %} + +df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") + +{% endhighlight %} + +
+ +
+ +{% highlight sql %} + +CREATE TEMPORARY TABLE jdbcTable +USING org.apache.spark.sql.jdbc +OPTIONS ( + url "jdbc:postgresql:dbserver", + dbtable "schema.tablename" +) + +{% endhighlight %} + +
+
+ +## Troubleshooting + + * The JDBC driver class must be visible to the primordial class loader on the client session and on all executors. This is because Java's DriverManager class does a security check that results in it ignoring all drivers not visible to the primordial class loader when one goes to open a connection. One convenient way to do this is to modify compute_classpath.sh on all worker nodes to include your driver JARs. + * Some databases, such as H2, convert all names to upper case. You'll need to use upper case to refer to those names in Spark SQL. + + # Performance Tuning For some workloads it is possible to improve performance by either caching data in memory, or by @@ -831,11 +1441,11 @@ turning on some experimental options. ## Caching Data In Memory -Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `schemaRDD.cache()`. +Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `dataFrame.cache()`. Then Spark SQL will scan only required columns and will automatically tune compression to minimize memory usage and GC pressure. You can call `sqlContext.uncacheTable("tableName")` to remove the table from memory. -Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running +Configuration of in-memory caching can be done using the `setConf` method on `SQLContext` or by running `SET key=value` commands using SQL. @@ -894,15 +1504,14 @@ that these options will be deprecated in future release as more optimizations ar
-# Other SQL Interfaces +# Distributed SQL Engine -Spark SQL also supports interfaces for running SQL queries directly without the need to write any -code. +Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, without the need to write any code. ## Running the Thrift JDBC/ODBC server The Thrift JDBC/ODBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) -in Hive 0.12. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.12. +in Hive 0.13. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.13. To start the JDBC/ODBC server, run the following in the Spark directory: @@ -947,10 +1556,10 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may also use the beeline script that comes with Hive. -Thrift JDBC server also supports sending thrift RPC messages over HTTP transport. -Use the following setting to enable HTTP mode as system property or in `hive-site.xml` file in `conf/`: +Thrift JDBC server also supports sending thrift RPC messages over HTTP transport. +Use the following setting to enable HTTP mode as system property or in `hive-site.xml` file in `conf/`: - hive.server2.transport.mode - Set this to value: http + hive.server2.transport.mode - Set this to value: http hive.server2.thrift.http.port - HTTP port number fo listen on; default is 10001 hive.server2.http.endpoint - HTTP endpoint; default is cliservice @@ -972,7 +1581,88 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may run `./bin/spark-sql --help` for a complete list of all available options. -# Compatibility with Other Systems +# Migration Guide + +## Upgrading from Spark SQL 1.0-1.2 to 1.3 + +In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the +available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other +releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked +as unstable (i.e., DeveloperAPI or Experimental). + +#### Rename of SchemaRDD to DataFrame + +The largest change that users will notice when upgrading to Spark SQL 1.3 is that `SchemaRDD` has +been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD +directly, but instead provide most of the functionality that RDDs provide though their own +implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. + +In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for +some use cases. It is still recommended that users update their code to use `DataFrame` instead. +Java and Python users will need to update their code. + +#### Unification of the Java and Scala APIs + +Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) +that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users +of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +use types that are usable from both languages (i.e. `Array` instead of language specific collections). +In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading +is used instead. + +Additionally the Java specific types API has been removed. Users of both Scala and Java should +use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. + + +#### Isolation of Implicit Conversions and Removal of dsl Package (Scala-only) + +Many of the code examples prior to Spark 1.3 started with `import sqlContext._`, which brought +all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit +conversions for converting `RDD`s into `DataFrame`s into an object inside of the `SQLContext`. +Users should now write `import sqlContext.implicits._`. + +Additionally, the implicit conversions now only augment RDDs that are composed of `Product`s (i.e., +case classes or tuples) with a method `toDF`, instead of applying automatically. + +When using function inside of the DSL (now replaced with the `DataFrame` API) users used to import +`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: +`import org.apache.spark.sql.functions._`. + +#### Removal of the type aliases in org.apache.spark.sql for DataType (Scala-only) + +Spark 1.3 removes the type aliases that were present in the base sql package for `DataType`. Users +should instead import the classes in `org.apache.spark.sql.types` + +#### UDF Registration Moved to `sqlContext.udf` (Java & Scala) + +Functions that are used to register UDFs, either for use in the DataFrame DSL or SQL, have been +moved into the udf object in `SQLContext`. + +
+
+{% highlight java %} + +sqlContext.udf.register("strLen", (s: String) => s.length()) + +{% endhighlight %} +
+ +
+{% highlight java %} + +sqlContext.udf().register("strLen", (String s) -> { s.length(); }); + +{% endhighlight %} +
+ +
+ +Python UDF registration is unchanged. + +#### Python DataTypes No Longer Singletons + +When using DataTypes in Python you will need to construct them (i.e. `StringType()`) instead of +referencing a singleton. ## Migration Guide for Shark User @@ -1092,15 +1782,11 @@ in Hive deployments. * Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL doesn't support buckets yet. + **Esoteric Hive Features** -* Tables with partitions using different input formats: In Spark SQL, all table partitions need to - have the same input format. -* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions - (e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple. -* `UNION` type and `DATE` type +* `UNION` type * Unique join -* Single query multi insert * Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at the moment and only supports populating the sizeInBytes field of the hive metastore. @@ -1116,9 +1802,6 @@ less important due to Spark SQL's in-memory computational model. Others are slot releases of Spark SQL. * Block level bitmap indexes and virtual columns (used to build indexes) -* Automatically convert a join to map join: For joining a large table with multiple small tables, - Hive automatically converts the join into a map join. We are adding this auto conversion in the - next release. * Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". * Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still @@ -1129,33 +1812,10 @@ releases of Spark SQL. Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS metadata. Spark SQL does not support that. -# Writing Language-Integrated Relational Queries -**Language-Integrated queries are experimental and currently only supported in Scala.** +# Data Types -Spark SQL also supports a domain specific language for writing queries. Once again, -using the data from the above examples: - -{% highlight scala %} -// sc is an existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// Importing the SQL context gives access to all the public SQL functions and implicit conversions. -import sqlContext._ -val people: RDD[Person] = ... // An RDD of case class objects, from the first example. - -// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' -val teenagers = people.where('age >= 10).where('age <= 19).select('name) -teenagers.map(t => "Name: " + t(0)).collect().foreach(println) -{% endhighlight %} - -The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers -prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are -evaluated by the SQL execution engine. A full list of the functions supported can be found in the -[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). - - - -# Spark SQL DataType Reference +Spark SQL and DataFrames support the following data types: * Numeric types - `ByteType`: Represents 1-byte signed integer numbers. @@ -1198,10 +1858,10 @@ evaluated by the SQL execution engine. A full list of the functions supported c
-All data types of Spark SQL are located in the package `org.apache.spark.sql`. +All data types of Spark SQL are located in the package `org.apache.spark.sql.types`. You can access them by doing {% highlight scala %} -import org.apache.spark.sql._ +import org.apache.spark.sql.types._ {% endhighlight %} @@ -1253,7 +1913,7 @@ import org.apache.spark.sql._ - + @@ -1447,7 +2107,7 @@ please use factory methods provided in - +
DecimalType scala.math.BigDecimal java.math.BigDecimal DecimalType
StructType org.apache.spark.sql.api.java.Row org.apache.spark.sql.Row DataTypes.createStructType(fields)
Note: fields is a List or an array of StructFields. @@ -1468,10 +2128,10 @@ please use factory methods provided in
-All data types of Spark SQL are located in the package of `pyspark.sql`. +All data types of Spark SQL are located in the package of `pyspark.sql.types`. You can access them by doing {% highlight python %} -from pyspark.sql import * +from pyspark.sql.types import * {% endhighlight %} diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index ac01dd3d8019a..c8ab146bcae0a 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -5,6 +5,8 @@ title: Spark Streaming + Flume Integration Guide [Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. +Python API Flume is not yet available in the Python API. + ## Approach 1: Flume-style Push-based Approach Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. @@ -64,7 +66,7 @@ configuring Flume agents. 3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). -## Approach 2 (Experimental): Pull-based Approach using a Custom Sink +## Approach 2: Pull-based Approach using a Custom Sink Instead of Flume pushing data directly to Spark Streaming, this approach runs a custom Flume sink that allows the following. - Flume pushes data into the sink, and the data stays buffered. diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 77c0abbbacbd0..64714f0b799fc 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -2,58 +2,155 @@ layout: global title: Spark Streaming + Kafka Integration Guide --- -[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. +[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. -1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). +## Approach 1: Receiver-based Approach +This approach uses a Receiver to receive the data. The Received is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. + +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming. To ensure zero data loss, enable the Write Ahead Logs (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. + +Next, we discuss how to use this approach in your streaming application. + +1. **Linking:** For Scala/Java applications using SBT/Maven project definitions, link your streaming application with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). groupId = org.apache.spark artifactId = spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION_SHORT}} -2. **Programming:** In the streaming application code, import `KafkaUtils` and create input DStream as follows. + For Python applications, you will have to add this above library and its dependencies when deploying your application. See the *Deploying* subsection below. + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows.
import org.apache.spark.streaming.kafka._ - val kafkaStream = KafkaUtils.createStream( - streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]) + val kafkaStream = KafkaUtils.createStream(streamingContext, + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) - See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
import org.apache.spark.streaming.kafka.*; - JavaPairReceiverInputDStream kafkaStream = KafkaUtils.createStream( - streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]); + JavaPairReceiverInputDStream kafkaStream = + KafkaUtils.createStream(streamingContext, + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); - See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + +
+
+ from pyspark.streaming.kafka import KafkaUtils + + kafkaStream = KafkaUtils.createStream(streamingContext, \ + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) + + By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/kafka_wordcount.py).
- *Points to remember:* + **Points to remember:** - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. -3. **Deploying:** Package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). - -Note that the Kafka receiver used by default is an -[*unreliable* receiver](streaming-programming-guide.html#receiver-reliability) section in the -programming guide). In Spark 1.2, we have added an experimental *reliable* Kafka receiver that -provides stronger -[fault-tolerance guarantees](streaming-programming-guide.html#fault-tolerance-semantics) of zero -data loss on failures. This receiver is automatically used when the write ahead log -(also introduced in Spark 1.2) is enabled -(see [Deployment](#deploying-applications.html) section in the programming guide). This -may reduce the receiving throughput of individual Kafka receivers compared to the unreliable -receivers, but this can be corrected by running -[more receivers in parallel](streaming-programming-guide.html#level-of-parallelism-in-data-receiving) -to increase aggregate throughput. Additionally, it is recommended that the replication of the -received data within Spark be disabled when the write ahead log is enabled as the log is already stored -in a replicated storage system. This can be done by setting the storage level for the input -stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use + - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use `KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). + +3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. + + For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + For Python applications which lack SBT/Maven project management, `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, + + ./bin/spark-submit --packages org.apache.spark:spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + + Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kafka-assembly` from the + [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. + +## Approach 2: Direct Approach (No Receivers) +This is a new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature in Spark 1.3 and is only available in the Scala and Java API. + +This approach has the following advantages over the received-based approach (i.e. Approach 1). + +- *Simplified Parallelism:* No need to create multiple input Kafka streams and union-ing them. With `directStream`, Spark Streaming will create as many RDD partitions as there is Kafka partitions to consume, which will all read data from Kafka in parallel. So there is one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. + +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminate the problem as there is no receiver, and hence no need for Write Ahead Logs. + +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper and offsets tracked only by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. + +Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). + +Next, we discuss how to use this approach in your streaming application. + +1. **Linking:** This approach is supported only in Scala/Java application. Link your SBT/Maven project with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows. + +
+
+ import org.apache.spark.streaming.kafka._ + + val directKafkaStream = KafkaUtils.createDirectStream[ + [key class], [value class], [key decoder class], [value decoder class] ]( + streamingContext, [map of Kafka parameters], [set of topics to consume]) + + See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). +
+
+ import org.apache.spark.streaming.kafka.*; + + JavaPairReceiverInputDStream directKafkaStream = + KafkaUtils.createDirectStream(streamingContext, + [key class], [value class], [key decoder class], [value decoder class], + [map of Kafka parameters], [set of topics to consume]); + + See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). + +
+
+ + In the Kafka parameters, you must specify either `metadata.broker.list` or `bootstrap.servers`. + By default, it will start consuming from the latest offset of each Kafka partition. If you set configuration `auto.offset.reset` in Kafka parameters to `smallest`, then it will start consuming from the smallest offset. + + You can also start consuming from any arbitrary offset using other variations of `KafkaUtils.createDirectStream`. Furthermore, if you want to access the Kafka offsets consumed in each batch, you can do the following. + +
+
+ directKafkaStream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges] + // offsetRanges.length = # of Kafka partitions being consumed + ... + } +
+
+ directKafkaStream.foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws IOException { + OffsetRange[] offsetRanges = ((HasOffsetRanges)rdd).offsetRanges + // offsetRanges.length = # of Kafka partitions being consumed + ... + return null; + } + } + ); +
+
+ + You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application. + + Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API. + +3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. \ No newline at end of file diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index e37a2bb37b9a4..e2c6a1c180d7a 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: Spark Streaming Programming Guide +displayTitle: Spark Streaming Programming Guide +title: Spark Streaming +description: Spark Streaming programming guide and tutorial for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). @@ -430,7 +432,7 @@ some of the common ones are as follows.
For an up-to-date list, please refer to the -[Apache repository](http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) +[Maven repository](http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) for the full list of supported sources and artifacts. *** @@ -660,8 +662,7 @@ methods for creating DStreams from files and Akka actors as input sources. For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores. - Python API As of Spark 1.2, - `fileStream` is not available in the Python API, only `textFileStream` is available. + Python API `fileStream` is not available in the Python API, only `textFileStream` is available. - **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver @@ -680,8 +681,9 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea ### Advanced Sources {:.no_toc} -Python API As of Spark 1.2, -these sources are not available in the Python API. + +Python API As of Spark 1.3, +out of these sources, *only* Kafka is available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts @@ -702,7 +704,7 @@ create a DStream using data from Twitter's stream of tweets, you have to do the {% highlight scala %} import org.apache.spark.streaming.twitter._ -TwitterUtils.createStream(ssc) +TwitterUtils.createStream(ssc, None) {% endhighlight %}
@@ -721,6 +723,12 @@ and it in the classpath. Some of these advanced sources are as follows. +- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.1.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. + +- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. + +- **Kinesis:** See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. + - **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by @@ -730,17 +738,10 @@ Some of these advanced sources are as follows. ([TwitterPopularTags]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala) and [TwitterAlgebirdCMS]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala)). -- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} can received data from Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. - -- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} can receive data from Kafka 0.8.0. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. - -- **Kinesis:** See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. - ### Custom Sources {:.no_toc} -Python API As of Spark 1.2, -these sources are not available in the Python API. +Python API This is not yet supported in Python. Input DStreams can also be created out of custom data sources. All you have to do is implement an user-defined **receiver** (see next section to understand what that is) that can receive data from @@ -844,7 +845,7 @@ Some of the common ones are as follows.
-The last two transformations are worth highlighting again. +A few of these transformations are worth discussing in more detail. #### UpdateStateByKey Operation {:.no_toc} @@ -876,6 +877,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi val runningCounts = pairs.updateStateByKey[Int](updateFunction _) {% endhighlight %} +The update function will be called for each word, with `newValues` having a sequence of 1's (from +the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete +Scala code, take a look at the example +[StatefulNetworkWordCount.scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache +/spark/examples/streaming/StatefulNetworkWordCount.scala). +
@@ -897,6 +904,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi JavaPairDStream runningCounts = pairs.updateStateByKey(updateFunction); {% endhighlight %} +The update function will be called for each word, with `newValues` having a sequence of 1's (from +the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete +Java code, take a look at the example +[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming +/JavaStatefulNetworkWordCount.java). +
@@ -914,14 +927,14 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi runningCounts = pairs.updateStateByKey(updateFunction) {% endhighlight %} -
-
- The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete -Scala code, take a look at the example +Python code, take a look at the example [stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py). +
+ + Note that using `updateStateByKey` requires the checkpoint directory to be configured, which is discussed in detail in the [checkpointing](#checkpointing) section. @@ -983,7 +996,7 @@ In fact, you can also use [machine learning](mllib-guide.html) and #### Window Operations {:.no_toc} -Finally, Spark Streaming also provides *windowed computations*, which allow you to apply +Spark Streaming also provides *windowed computations*, which allow you to apply transformations over a sliding window of data. This following figure illustrates this sliding window. @@ -1106,6 +1119,100 @@ said two parameters - windowLength and slideInterval. +#### Join Operations +{:.no_toc} +Finally, its worth highlighting how easily you can perform different kinds of joins in Spark Streaming. + + +##### Stream-stream joins +{:.no_toc} +Streams can be very easily joined with other streams. + +
+
+{% highlight scala %} +val stream1: DStream[String, String] = ... +val stream2: DStream[String, String] = ... +val joinedStream = stream1.join(stream2) +{% endhighlight %} +
+
+{% highlight java %} +JavaPairDStream stream1 = ... +JavaPairDStream stream2 = ... +JavaPairDStream joinedStream = stream1.join(stream2); +{% endhighlight %} +
+
+{% highlight python %} +stream1 = ... +stream2 = ... +joinedStream = stream1.join(stream2) +{% endhighlight %} +
+
+Here, in each batch interval, the RDD generated by `stream1` will be joined with the RDD generated by `stream2`. You can also do `leftOuterJoin`, `rightOuterJoin`, `fullOuterJoin`. Furthermore, it is often very useful to do joins over windows of the streams. That is pretty easy as well. + +
+
+{% highlight scala %} +val windowedStream1 = stream1.window(Seconds(20)) +val windowedStream2 = stream2.window(Minutes(1)) +val joinedStream = windowedStream1.join(windowedStream2) +{% endhighlight %} +
+
+{% highlight java %} +JavaPairDStream windowedStream1 = stream1.window(Durations.seconds(20)); +JavaPairDStream windowedStream2 = stream2.window(Durations.minutes(1)); +JavaPairDStream joinedStream = windowedStream1.join(windowedStream2); +{% endhighlight %} +
+
+{% highlight python %} +windowedStream1 = stream1.window(20) +windowedStream2 = stream2.window(60) +joinedStream = windowedStream1.join(windowedStream2) +{% endhighlight %} +
+
+ +##### Stream-dataset joins +{:.no_toc} +This has already been shown earlier while explain `DStream.transform` operation. Here is yet another example of joining a windowed stream with a dataset. + +
+
+{% highlight scala %} +val dataset: RDD[String, String] = ... +val windowedStream = stream.window(Seconds(20))... +val joinedStream = windowedStream.transform { rdd => rdd.join(dataset) } +{% endhighlight %} +
+
+{% highlight java %} +JavaPairRDD dataset = ... +JavaPairDStream windowedStream = stream.window(Durations.seconds(20)); +JavaPairDStream joinedStream = windowedStream.transform( + new Function>, JavaRDD>>() { + @Override + public JavaRDD> call(JavaRDD> rdd) { + return rdd.join(dataset); + } + } +); +{% endhighlight %} +
+
+{% highlight python %} +dataset = ... # some RDD +windowedStream = stream.window(20) +joinedStream = windowedStream.transform(lambda rdd: rdd.join(dataset)) +{% endhighlight %} +
+
+ +In fact, you can also dynamically change the dataset you want to join against. The function provided to `transform` is evaluated every batch interval and therefore will use the current dataset that `dataset` reference points to. The complete list of DStream transformations is available in the API documentation. For the Scala API, see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) @@ -1313,6 +1420,178 @@ Note that the connections in the pool should be lazily created on demand and tim *** +## DataFrame and SQL Operations +You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SQLContext using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SQLContext. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. + +
+
+{% highlight scala %} + +/** Lazily instantiated singleton instance of SQLContext */ +object SQLContextSingleton { + @transient private var instance: SQLContext = null + + // Instantiate SQLContext on demand + def getInstance(sparkContext: SparkContext): SQLContext = synchronized { + if (instance == null) { + instance = new SQLContext(sparkContext) + } + instance + } +} + +... + +/** Case class for converting RDD to DataFrame */ +case class Row(word: String) + +... + +/** DataFrame operations inside your streaming program */ + +val words: DStream[String] = ... + +words.foreachRDD { rdd => + + // Get the singleton instance of SQLContext + val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) + import sqlContext.implicits._ + + // Convert RDD[String] to RDD[case class] to DataFrame + val wordsDataFrame = rdd.map(w => Row(w)).toDF() + + // Register as table + wordsDataFrame.registerTempTable("words") + + // Do word count on DataFrame using SQL and print it + val wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word") + wordCountsDataFrame.show() +} + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala). +
+
+{% highlight java %} + +/** Lazily instantiated singleton instance of SQLContext */ +class JavaSQLContextSingleton { + static private transient SQLContext instance = null; + static public SQLContext getInstance(SparkContext sparkContext) { + if (instance == null) { + instance = new SQLContext(sparkContext); + } + return instance; + } +} + +... + +/** Java Bean class for converting RDD to DataFrame */ +public class JavaRow implements java.io.Serializable { + private String word; + + public String getWord() { + return word; + } + + public void setWord(String word) { + this.word = word; + } +} + +... + +/** DataFrame operations inside your streaming program */ + +JavaDStream words = ... + +words.foreachRDD( + new Function2, Time, Void>() { + @Override + public Void call(JavaRDD rdd, Time time) { + SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context()); + + // Convert RDD[String] to RDD[case class] to DataFrame + JavaRDD rowRDD = rdd.map(new Function() { + public JavaRow call(String word) { + JavaRow record = new JavaRow(); + record.setWord(word); + return record; + } + }); + DataFrame wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRow.class); + + // Register as table + wordsDataFrame.registerTempTable("words"); + + // Do word count on table using SQL and print it + DataFrame wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word"); + wordCountsDataFrame.show(); + return null; + } + } +); +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java). +
+
+{% highlight python %} + +# Lazily instantiated global instance of SQLContext +def getSqlContextInstance(sparkContext): + if ('sqlContextSingletonInstance' not in globals()): + globals()['sqlContextSingletonInstance'] = SQLContext(sparkContext) + return globals()['sqlContextSingletonInstance'] + +... + +# DataFrame operations inside your streaming program + +words = ... # DStream of strings + +def process(time, rdd): + print "========= %s =========" % str(time) + try: + # Get the singleton instance of SQLContext + sqlContext = getSqlContextInstance(rdd.context) + + # Convert RDD[String] to RDD[Row] to DataFrame + rowRdd = rdd.map(lambda w: Row(word=w)) + wordsDataFrame = sqlContext.createDataFrame(rowRdd) + + # Register as table + wordsDataFrame.registerTempTable("words") + + # Do word count on table using SQL and print it + wordCountsDataFrame = sqlContext.sql("select word, count(*) as total from words group by word") + wordCountsDataFrame.show() + except: + pass + +words.foreachRDD(process) +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/sql_network_wordcount.py). + +
+
+ +You can also run SQL queries on tables defined on streaming data from a different thread (that is, asynchronous to the running StreamingContext). Just make sure that you set the StreamingContext to remember sufficient amount of streaming data such that query can run. Otherwise the StreamingContext, which is unaware of the any asynchronous SQL queries, will delete off old streaming data before the query can complete. For example, if you want to query the last batch, but your query can take 5 minutes to run, then call `streamingContext.remember(Minutes(5))` (in Scala, or equivalent in other languages). + +See the [DataFrames and SQL](sql-programming-guide.html) guide to learn more about DataFrames. + +*** + +## MLlib Operations +You can also easily use machine learning algorithms provided by [MLlib](mllib-guide.html). First of all, there are streaming machine learning algorithms (e.g. (Streaming Linear Regression](mllib-linear-methods.html#streaming-linear-regression), [Streaming KMeans](mllib-clustering.html#streaming-k-means), etc.) which can simultaneously learn from the streaming data as well as apply the model on the streaming data. Beyond these, for a much larger class of machine learning algorithms, you can learn a learning model offline (i.e. using historical data) and then apply the model online on streaming data. See the [MLlib](mllib-guide.html) guide for more details. + +*** + ## Caching / Persistence Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream will automatically persist every RDD of that DStream in @@ -1566,9 +1845,8 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. - -- *[Experimental in Spark 1.2] Configuring write ahead logs* - In Spark 1.2, - we have introduced a new experimental feature of write ahead logs for achieving strong +- *[Since Spark 1.2] Configuring write ahead logs* - Since Spark 1.2, + we have introduced _write ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into a write ahead log in the configuration checkpoint directory. This prevents data loss on driver recovery, thus ensuring zero data loss (discussed in detail in the @@ -1654,7 +1932,7 @@ improve the performance of you application. At a high level, you need to conside 2. Setting the right batch size such that the batches of data can be processed as fast as they are received (that is, data processing keeps up with the data ingestion). -## Reducing the Processing Time of each Batch +## Reducing the Batch Processing Times There are a number of optimizations that can be done in Spark to minimize the processing time of each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section highlights some of the most important ones. @@ -1726,16 +2004,15 @@ documentation), or set the `spark.default.parallelism` ### Data Serialization {:.no_toc} -The overhead of data serialization can be significant, especially when sub-second batch sizes are - to be achieved. There are two aspects to it. +The overheads of data serialization can be reduce by tuning the serialization formats. In case of streaming, there are two types of data that are being serialized. -* **Serialization of RDD data in Spark**: Please refer to the detailed discussion on data - serialization in the [Tuning Guide](tuning.html). However, note that unlike Spark, by default - RDDs are persisted as serialized byte arrays to minimize pauses related to GC. +* **Input data**: By default, the input data received through Receivers is stored in the executors' memory with [StorageLevel.MEMORY_AND_DISK_SER_2](api/scala/index.html#org.apache.spark.storage.StorageLevel$). That is, the data is serialized into bytes to reduce GC overheads, and replicated for tolerating executor failures. Also, the data is kept first in memory, and spilled over to disk only if the memory is unsufficient to hold all the input data necessary for the streaming computation. This serialization obviously has overheads -- the receiver must deserialize the received data and re-serialize it using Spark's serialization format. -* **Serialization of input data**: To ingest external data into Spark, data received as bytes - (say, from the network) needs to deserialized from bytes and re-serialized into Spark's - serialization format. Hence, the deserialization overhead of input data may be a bottleneck. +* **Persisted RDDs generated by Streaming Operations**: RDDs generated by streaming computations may be persisted in memory. For example, window operation persist data in memory as they would be processed multiple times. However, unlike Spark, by default RDDs are persisted with [StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) (i.e. serialized) to minimize GC overheads. + +In both cases, using Kryo serialization can reduce both CPU and memory overheads. See the [Spark Tuning Guide](tuning.html#data-serialization)) for more details. Consider registering custom classes, and disabling object reference tracking for Kryo (see Kryo-related configurations in the [Configuration Guide](configuration.html#compression-and-serialization)). + +In specific cases where the amount of data that needs to be retained for the streaming application is not large, it may be feasible to persist data (both types) as deserialized objects without incurring excessive GC overheads. For example, if you are using batch intervals of few seconds and no window operations, then you can try disabling serialization in persisted data by explicitly setting the storage level accordingly. This would reduce the CPU overheads due to serialization, potentially improving performance without too much GC overheads. ### Task Launching Overheads {:.no_toc} @@ -1755,7 +2032,7 @@ thus allowing sub-second batch size to be viable. *** -## Setting the Right Batch Size +## Setting the Right Batch Interval For a Spark Streaming application running on a cluster to be stable, the system should be able to process data as fast as it is being received. In other words, batches of data should be processed as fast as they are being generated. Whether this is true for an application can be found by @@ -1787,40 +2064,40 @@ temporary data rate increases maybe fine as long as the delay reduces back to a ## Memory Tuning Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail -in the [Tuning Guide](tuning.html). It is recommended that you read that. In this section, -we highlight a few customizations that are strongly recommended to minimize GC related pauses -in Spark Streaming applications and achieving more consistent batch processing times. - -* **Default persistence level of DStreams**: Unlike RDDs, the default persistence level of DStreams -serializes the data in memory (that is, -[StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) for -DStream compared to -[StorageLevel.MEMORY_ONLY](api/scala/index.html#org.apache.spark.storage.StorageLevel$) for RDDs). -Even though keeping the data serialized incurs higher serialization/deserialization overheads, -it significantly reduces GC pauses. - -* **Clearing persistent RDDs**: By default, all persistent RDDs generated by Spark Streaming will - be cleared from memory based on Spark's built-in policy (LRU). If `spark.cleaner.ttl` is set, - then persistent RDDs that are older than that value are periodically cleared. As mentioned - [earlier](#operation), this needs to be careful set based on operations used in the Spark - Streaming program. However, a smarter unpersisting of RDDs can be enabled by setting the - [configuration property](configuration.html#spark-properties) `spark.streaming.unpersist` to - `true`. This makes the system to figure out which RDDs are not necessary to be kept around and - unpersists them. This is likely to reduce - the RDD memory usage of Spark, potentially improving GC behavior as well. - -* **Concurrent garbage collector**: Using the concurrent mark-and-sweep GC further -minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the +in the [Tuning Guide](tuning.html#memory-tuning). It is strongly recommended that you read that. In this section, we discuss a few tuning parameters specifically in the context of Spark Streaming applications. + +The amount of cluster memory required by a Spark Streaming application depends heavily on the type of transformations used. For example, if you want to use a window operation on last 10 minutes of data, then your cluster should have sufficient memory to hold 10 minutes of worth of data in memory. Or if you want to use `updateStateByKey` with a large number of keys, then the necessary memory will be high. On the contrary, if you want to do a simple map-filter-store operation, then necessary memory will be low. + +In general, since the data received through receivers are stored with StorageLevel.MEMORY_AND_DISK_SER_2, the data that does not fit in memory will spill over to the disk. This may reduce the performance of the streaming application, and hence it is advised to provide sufficient memory as required by your streaming application. Its best to try and see the memory usage on a small scale and estimate accordingly. + +Another aspect of memory tuning is garbage collection. For a streaming application that require low latency, it is undesirable to have large pauses caused by JVM Garbage Collection. + +There are a few parameters that can help you tune the memory usage and GC overheads. + +* **Persistence Level of DStreams**: As mentioned earlier in the [Data Serialization](#data-serialization) section, the input data and RDDs are by default persisted as serialized bytes. This reduces both, the memory usage and GC overheads, compared to deserialized persistence. Enabling Kryo serialization further reduces serialized sizes and memory usage. Further reduction in memory usage can be achieved with compression (see the Spark configuration `spark.rdd.compress`), at the cost of CPU time. + +* **Clearing old data**: By default, all input data and persisted RDDs generated by DStream transformations are automatically cleared. Spark Streaming decides when to clear the data based on the transformations that are used. For example, if you are using window operation of 10 minutes, then Spark Streaming will keep around last 10 minutes of data, and actively throw away older data. +Data can be retained for longer duration (e.g. interactively querying older data) by setting `streamingContext.remember`. + +* **CMS Garbage Collector**: Use of the concurrent mark-and-sweep GC is strongly recommended for keeping GC-related pauses consistently low. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more -consistent batch processing times. +consistent batch processing times. Make sure you set the CMS GC on both the driver (using `--driver-java-options` in `spark-submit`) and the executors (using [Spark configuration](configuration.html#runtime-environment) `spark.executor.extraJavaOptions`). + +* **Other tips**: To further reduce GC overheads, here are some more tips to try. + - Use Tachyon for off-heap storage of persisted RDDs. See more detail in the [Spark Programming Guide](programming-guide.html#rdd-persistence). + - Use more executors with smaller heap sizes. This will reduce the GC pressure within each JVM heap. + *************************************************************************************************** *************************************************************************************************** # Fault-tolerance Semantics In this section, we will discuss the behavior of Spark Streaming applications in the event -of node failures. To understand this, let us remember the basic fault-tolerance semantics of -Spark's RDDs. +of failures. + +## Background +{:.no_toc} +To understand the semantics provided by Spark Streaming, let us remember the basic fault-tolerance semantics of Spark's RDDs. 1. An RDD is an immutable, deterministically re-computable, distributed dataset. Each RDD remembers the lineage of deterministic operations that were used on a fault-tolerant input @@ -1854,13 +2131,43 @@ Furthermore, there are two kinds of failures that we should be concerned about: With this basic knowledge, let us understand the fault-tolerance semantics of Spark Streaming. -## Semantics with files as input source +## Definitions +{:.no_toc} +The semantics of streaming systems are often captured in terms of how many times each record can be processed by the system. There are three types of guarantees that a system can provide under all possible operating conditions (despite failures, etc.) + +1. *At most once*: Each record will be either processed once or not processed at all. +2. *At least once*: Each record will be processed one or more times. This is stronger than *at-most once* as it ensure that no data will be lost. But there may be duplicates. +3. *Exactly once*: Each record will be processed exactly once - no data will be lost and no data will be processed multiple times. This is obviously the strongest guarantee of the three. + +## Basic Semantics +{:.no_toc} +In any stream processing system, broadly speaking, there are three steps in processing the data. + +1. *Receiving the data*: The data is received from sources using Receivers or otherwise. + +1. *Transforming the data*: The data received data is transformed using DStream and RDD transformations. + +1. *Pushing out the data*: The final transformed data is pushed out to external systems like file systems, databases, dashboards, etc. + +If a streaming application has to achieve end-to-end exactly-once guarantees, then each step has to provide exactly-once guarantee. That is, each record must be received exactly once, transformed exactly once, and pushed to downstream systems exactly once. Let's understand the semantics of these steps in the context of Spark Streaming. + +1. *Receiving the data*: Different input sources provided different guarantees. This is discussed in detail in the next subsection. + +1. *Transforming the data*: All data that has been received will be processed _exactly once_, thanks to the guarantees that RDDs provide. Even if there are failures, as long as the received input data is accessible, the final transformed RDDs will always have the same contents. + +1. *Pushing out the data*: Output operations by default ensure _at-least once_ semantics because it depends on the type of output operation (idempotent, or not) and the semantics of the downstream system (supports transactions or not). But users can implement their own transaction mechanisms to achieve _exactly-once_ semantics. This is discussed in more details later in the section. + +## Semantics of Received Data +{:.no_toc} +Different input sources provide different guarantees, ranging from _at-least once_ to _exactly once_. Read for more details. + +### With Files {:.no_toc} If all of the input data is already present in a fault-tolerant files system like HDFS, Spark Streaming can always recover from any failure and process all the data. This gives *exactly-once* semantics, that all the data will be processed exactly once no matter what fails. -## Semantics with input sources based on receivers +### With Receiver-based Sources {:.no_toc} For input sources based on receivers, the fault-tolerance semantics depend on both the failure scenario and the type of receiver. @@ -1879,10 +2186,9 @@ receivers, data received but not replicated can get lost. If the driver node fai then besides these losses, all the past data that was received and replicated in memory will be lost. This will affect the results of the stateful transformations. -To avoid this loss of past received data, Spark 1.2 introduces an experimental feature of _write +To avoid this loss of past received data, Spark 1.2 introduced _write ahead logs_ which saves the received data to fault-tolerant storage. With the [write ahead logs -enabled](#deploying-applications) and reliable receivers, there is zero data loss and -exactly-once semantics. +enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides at-least once guarantee. The following table summarizes the semantics under failures: @@ -1894,23 +2200,30 @@ The following table summarizes the semantics under failures: - Spark 1.1 or earlier, or
- Spark 1.2 without write ahead log + Spark 1.1 or earlier, OR
+ Spark 1.2 or later without write ahead logs Buffered data lost with unreliable receivers
- Zero data loss with reliable receivers and files
+ Zero data loss with reliable receivers
+ At-least once semantics Buffered data lost with unreliable receivers
Past data lost with all receivers
- Zero data loss with files - + Undefined semantics + - Spark 1.2 with write ahead log - Zero data loss with reliable receivers and files - Zero data loss with reliable receivers and files + Spark 1.2 or later with write ahead logs + + Zero data loss with reliable receivers
+ At-least once semantics + + + Zero data loss with reliable receivers and files
+ At-least once semantics + @@ -1919,17 +2232,24 @@ The following table summarizes the semantics under failures: +### With Kafka Direct API +{:.no_toc} +In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark 1.3) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). + ## Semantics of output operations {:.no_toc} -Since all data is modeled as RDDs with their lineage of deterministic operations, any recomputation - always leads to the same result. As a result, all DStream transformations are guaranteed to have - _exactly-once_ semantics. That is, the final transformed result will be same even if there were - was a worker node failure. However, output operations (like `foreachRDD`) have _at-least once_ - semantics, that is, the transformed data may get written to an external entity more than once in - the event of a worker failure. While this is acceptable for saving to HDFS using the - `saveAs***Files` operations (as the file will simply get over-written by the same data), - additional transactions-like mechanisms may be necessary to achieve exactly-once semantics - for output operations. +Output operations (like `foreachRDD`) have _at-least once_ semantics, that is, +the transformed data may get written to an external entity more than once in +the event of a worker failure. While this is acceptable for saving to file systems using the +`saveAs***Files` operations (as the file will simply get overwritten with the same data), +additional effort may be necessary to achieve exactly-once semantics. There are two approaches. + +- *Idempotent updates*: Multiple attempts always write the same data. For example, `saveAs***Files` always writes the same data to the generated files. + +- *Transactional updates*: All updates are made transactionally so that updates are made exactly once atomically. One way to do this would be the following. + + - Use the batch time (available in `foreachRDD`) and the partition index of the transformed RDD to create an identifier. This identifier uniquely identifies a blob data in the streaming application. + - Update external system with this blob transactionally (that is, exactly once, atomically) using the identifier. That is, if the identifier is not already committed, commit the partition data and the identifier atomically. Else if this was already committed, skip the update. *************************************************************************************************** @@ -1987,7 +2307,11 @@ package and renamed for better clarity. *************************************************************************************************** # Where to Go from Here - +* Additional guides + - [Kafka Integration Guide](streaming-kafka-integration.html) + - [Flume Integration Guide](streaming-flume-integration.html) + - [Kinesis Integration Guide](streaming-kinesis-integration.html) + - [Custom Receiver Guide](streaming-custom-receivers.html) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and @@ -2009,8 +2333,8 @@ package and renamed for better clarity. [ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and [MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html) - Python docs - * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) - * [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) + * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) and [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) + * [KafkaUtils](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) * More examples in [Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming) and [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 14a87f8436984..57b074778f2b0 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -174,6 +174,11 @@ This can use up a significant amount of space over time and will need to be clea is handled automatically, and with Spark standalone, automatic cleanup can be configured with the `spark.worker.cleanup.appDataTtl` property. +Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates +with `--packages`. All transitive dependencies will be handled when using this command. Additional +repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. +These commands can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages. + For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries to executors. diff --git a/docs/tuning.md b/docs/tuning.md index efaac9d3d405f..cbd227868b248 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -1,6 +1,8 @@ --- layout: global -title: Tuning Spark +displayTitle: Tuning Spark +title: Tuning +description: Tuning and performance optimization guide for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index abab209a05ba0..e121dad46278c 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -32,6 +32,7 @@ import sys import tarfile import tempfile +import textwrap import time import urllib2 import warnings @@ -39,6 +40,9 @@ from optparse import OptionParser from sys import stderr +SPARK_EC2_VERSION = "1.3.1" +SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) + VALID_SPARK_VERSIONS = set([ "0.7.3", "0.8.0", @@ -52,11 +56,13 @@ "1.1.0", "1.1.1", "1.2.0", + "1.2.1", + "1.3.0", + "1.3.1", ]) -DEFAULT_SPARK_VERSION = "1.2.0" +DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" -SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) MESOS_SPARK_EC2_BRANCH = "branch-1.3" # A URL prefix from which to fetch AMI information @@ -103,12 +109,10 @@ class UsageError(Exception): # Configure and parse our command-line arguments def parse_args(): parser = OptionParser( - usage="spark-ec2 [options] " - + "\n\n can be: launch, destroy, login, stop, start, get-master, reboot-slaves", - add_help_option=False) - parser.add_option( - "-h", "--help", action="help", - help="Show this help message and exit") + prog="spark-ec2", + version="%prog {v}".format(v=SPARK_EC2_VERSION), + usage="%prog [options] \n\n" + + " can be: launch, destroy, login, stop, start, get-master, reboot-slaves") parser.add_option( "-s", "--slaves", type="int", default=1, help="Number of slaves to launch (default: %default)") @@ -569,6 +573,9 @@ def launch_cluster(conn, opts, cluster_name): master_nodes = master_res.instances print "Launched master in %s, regid = %s" % (zone, master_res.id) + # This wait time corresponds to SPARK-4983 + print "Waiting for AWS to propagate instance metadata..." + time.sleep(5) # Give the instances descriptive names for master in master_nodes: master.add_tag( @@ -675,21 +682,32 @@ def setup_spark_cluster(master, opts): print "Ganglia started at http://%s:5080/ganglia" % master -def is_ssh_available(host, opts): +def is_ssh_available(host, opts, print_ssh_output=True): """ Check if SSH is available on a host. """ - try: - with open(os.devnull, 'w') as devnull: - ret = subprocess.check_call( - ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', - '%s@%s' % (opts.user, host), stringify_command('true')], - stdout=devnull, - stderr=devnull - ) - return ret == 0 - except subprocess.CalledProcessError as e: - return False + s = subprocess.Popen( + ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', + '%s@%s' % (opts.user, host), stringify_command('true')], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order + ) + cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout + + if s.returncode != 0 and print_ssh_output: + # extra leading newline is for spacing in wait_for_cluster_state() + print textwrap.dedent("""\n + Warning: SSH connection error. (This could be temporary.) + Host: {h} + SSH return code: {r} + SSH output: {o} + """).format( + h=host, + r=s.returncode, + o=cmd_output.strip() + ) + + return s.returncode == 0 def is_cluster_ssh_available(cluster_instances, opts): @@ -697,7 +715,7 @@ def is_cluster_ssh_available(cluster_instances, opts): Check if SSH is available on all the instances in a cluster. """ for i in cluster_instances: - if not is_ssh_available(host=i.ip_address, opts=opts): + if not is_ssh_available(host=i.public_dns_name, opts=opts): return False else: return True @@ -896,6 +914,7 @@ def stringify_command(parts): def ssh_args(opts): parts = ['-o', 'StrictHostKeyChecking=no'] + parts += ['-o', 'UserKnownHostsFile=/dev/null'] if opts.identity_file is not None: parts += ['-i', opts.identity_file] return parts @@ -1082,11 +1101,12 @@ def real_main(): time.sleep(30) # Yes, it does have to be this long :-( for group in groups: try: - conn.delete_security_group(group.name) - print "Deleted security group " + group.name + # It is needed to use group_id to make it work with VPC + conn.delete_security_group(group_id=group.id) + print "Deleted security group %s" % group.name except boto.exception.EC2ResponseError: success = False - print "Failed to delete security group " + group.name + print "Failed to delete security group %s" % group.name # Unfortunately, group.revoke() returns True even if a rule was not # deleted, so this needs to be rerun if something fails diff --git a/examples/pom.xml b/examples/pom.xml index 8caad2bc2e27a..e1a3eccc5ad28 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../pom.xml diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java new file mode 100644 index 0000000000000..bab9f2478e779 --- /dev/null +++ b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Arrays; +import java.util.regex.Pattern; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import kafka.serializer.StringDecoder; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.*; +import org.apache.spark.streaming.api.java.*; +import org.apache.spark.streaming.kafka.KafkaUtils; +import org.apache.spark.streaming.Durations; + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. + * Usage: DirectKafkaWordCount + * is a list of one or more Kafka brokers + * is a list of one or more kafka topics to consume from + * + * Example: + * $ bin/run-example streaming.KafkaWordCount broker1-host:port,broker2-host:port topic1,topic2 + */ + +public final class JavaDirectKafkaWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: DirectKafkaWordCount \n" + + " is a list of one or more Kafka brokers\n" + + " is a list of one or more kafka topics to consume from\n\n"); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + String brokers = args[0]; + String topics = args[1]; + + // Create context with 2 second batch interval + SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); + JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(2)); + + HashSet topicsSet = new HashSet(Arrays.asList(topics.split(","))); + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", brokers); + + // Create direct kafka stream with brokers and topics + JavaPairInputDStream messages = KafkaUtils.createDirectStream( + jssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topicsSet + ); + + // Get the lines, split them into words, count the words and print + JavaDStream lines = messages.map(new Function, String>() { + @Override + public String call(Tuple2 tuple2) { + return tuple2._2(); + } + }); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); + wordCounts.print(); + + // Start the computation + jssc.start(); + jssc.awaitTermination(); + } +} diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala new file mode 100644 index 0000000000000..1c8a20bf8f1ae --- /dev/null +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import kafka.serializer.StringDecoder + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.kafka._ +import org.apache.spark.SparkConf + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. + * Usage: DirectKafkaWordCount + * is a list of one or more Kafka brokers + * is a list of one or more kafka topics to consume from + * + * Example: + * $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port \ + * topic1,topic2 + */ +object DirectKafkaWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println(s""" + |Usage: DirectKafkaWordCount + | is a list of one or more Kafka brokers + | is a list of one or more kafka topics to consume from + | + """".stripMargin) + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + val Array(brokers, topics) = args + + // Create context with 2 second batch interval + val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") + val ssc = new StreamingContext(sparkConf, Seconds(2)) + + // Create direct kafka stream with brokers and topics + val topicsSet = topics.split(",").toSet + val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers) + val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topicsSet) + + // Get the lines, split them into words, count the words and print + val lines = messages.map(_._2) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1L)).reduceByKey(_ + _) + wordCounts.print() + + // Start the computation + ssc.start() + ssc.awaitTermination() + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java index 0fbee6e433608..9bbc14ea40875 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -34,8 +34,8 @@ import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple example demonstrating model selection using CrossValidator. @@ -71,7 +71,7 @@ public static void main(String[] args) { new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); - DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -112,14 +112,15 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); - for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + DataFrame predictions = cvModel.transform(test); + for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java new file mode 100644 index 0000000000000..19d0eb216848e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.List; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.Classifier; +import org.apache.spark.ml.classification.ClassificationModel; +import org.apache.spark.ml.param.IntParam; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.param.Params; +import org.apache.spark.ml.param.Params$; +import org.apache.spark.mllib.linalg.BLAS; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + + +/** + * A simple example demonstrating how to write your own learning algorithm using Estimator, + * Transformer, and other abstractions. + * This mimics {@link org.apache.spark.ml.classification.LogisticRegression}. + * + * Run with + *
+ * bin/run-example ml.JavaDeveloperApiExample
+ * 
+ */ +public class JavaDeveloperApiExample { + + public static void main(String[] args) throws Exception { + SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // Prepare training data. + List localTraining = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); + + // Create a LogisticRegression instance. This instance is an Estimator. + MyJavaLogisticRegression lr = new MyJavaLogisticRegression(); + // Print out the parameters, documentation, and any default values. + System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n"); + + // We may set parameters using setter methods. + lr.setMaxIter(10); + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + MyJavaLogisticRegressionModel model = lr.fit(training); + + // Prepare test data. + List localTest = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + DataFrame results = model.transform(test); + double sumPredictions = 0; + for (Row r : results.select("features", "label", "prediction").collect()) { + sumPredictions += r.getDouble(2); + } + if (sumPredictions != 0.0) { + throw new Exception("MyJavaLogisticRegression predicted something other than 0," + + " even though all weights are 0!"); + } + + jsc.stop(); + } +} + +/** + * Example of defining a type of {@link Classifier}. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +class MyJavaLogisticRegression + extends Classifier + implements Params { + + /** + * Param for max number of iterations + *

+ * NOTE: The usual way to add a parameter to a model or algorithm is to include: + * - val myParamName: ParamType + * - def getMyParamName + * - def setMyParamName + */ + IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations"); + + int getMaxIter() { return (Integer) get(maxIter); } + + public MyJavaLogisticRegression() { + setMaxIter(100); + } + + // The parameter setter is in this class since it should return type MyJavaLogisticRegression. + MyJavaLogisticRegression setMaxIter(int value) { + return (MyJavaLogisticRegression) set(maxIter, value); + } + + // This method is used by fit(). + // In Java, we have to make it public since Java does not understand Scala's protected modifier. + public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) { + // Extract columns from data using helper method. + JavaRDD oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD(); + + // Do learning to estimate the weight vector. + int numFeatures = oldDataset.take(1).get(0).features().size(); + Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. + + // Create a model, and return it. + return new MyJavaLogisticRegressionModel(this, paramMap, weights); + } +} + +/** + * Example of defining a type of {@link ClassificationModel}. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +class MyJavaLogisticRegressionModel + extends ClassificationModel implements Params { + + private MyJavaLogisticRegression parent_; + public MyJavaLogisticRegression parent() { return parent_; } + + private ParamMap fittingParamMap_; + public ParamMap fittingParamMap() { return fittingParamMap_; } + + private Vector weights_; + public Vector weights() { return weights_; } + + public MyJavaLogisticRegressionModel( + MyJavaLogisticRegression parent_, + ParamMap fittingParamMap_, + Vector weights_) { + this.parent_ = parent_; + this.fittingParamMap_ = fittingParamMap_; + this.weights_ = weights_; + } + + // This uses the default implementation of transform(), which reads column "features" and outputs + // columns "prediction" and "rawPrediction." + + // This uses the default implementation of predict(), which chooses the label corresponding to + // the maximum value returned by [[predictRaw()]]. + + /** + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + * + * In Java, we have to make this method public since Java does not understand Scala's protected + * modifier. + */ + public Vector predictRaw(Vector features) { + double margin = BLAS.dot(features, weights_); + // There are 2 classes (binary classification), so we return a length-2 vector, + // where index i corresponds to class i (i = 0, 1). + return Vectors.dense(-margin, margin); + } + + /** + * Number of classes the label can take. 2 indicates binary classification. + */ + public int numClasses() { return 2; } + + /** + * Create a copy of the model. + * The copy is shallow, except for the embedded paramMap, which gets a deep copy. + *

+ * This is used for the defaul implementation of [[transform()]]. + * + * In Java, we have to make this method public since Java does not understand Scala's protected + * modifier. + */ + public MyJavaLogisticRegressionModel copy() { + MyJavaLogisticRegressionModel m = + new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_); + Params$.MODULE$.inheritValues(this.paramMap(), this, m); + return m; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index eaaa344be49c8..4e02acce696e6 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -29,8 +29,8 @@ import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple example demonstrating ways to specify parameters for Estimators and Transformers. @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -81,7 +81,7 @@ public static void main(String[] args) { // One can also combine ParamMaps. ParamMap paramMap2 = new ParamMap(); - paramMap2.put(lr.scoreCol().w("probability")); // Change output column name + paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. @@ -94,18 +94,18 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. - // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' - // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test).registerTempTable("results"); - DataFrame results = - jsql.sql("SELECT features, label, probability, prediction FROM results"); - for (Row r: results.collect()) { + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + DataFrame results = model2.transform(test); + for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 82d665a3e1386..ef1ec103a879f 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -30,8 +30,8 @@ import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -79,14 +79,15 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. - model.transform(test).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); - for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + DataFrame predictions = model.transform(test); + for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java new file mode 100644 index 0000000000000..36baf5868736c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import java.util.ArrayList; + +import com.google.common.base.Joiner; +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.fpm.FPGrowth; +import org.apache.spark.mllib.fpm.FPGrowthModel; + +/** + * Java example for mining frequent itemsets using FP-growth. + * Example usage: ./bin/run-example mllib.JavaFPGrowthExample ./data/mllib/sample_fpgrowth.txt + */ +public class JavaFPGrowthExample { + + public static void main(String[] args) { + String inputFile; + double minSupport = 0.3; + int numPartition = -1; + if (args.length < 1) { + System.err.println( + "Usage: JavaFPGrowth [minSupport] [numPartition]"); + System.exit(1); + } + inputFile = args[0]; + if (args.length >= 2) { + minSupport = Double.parseDouble(args[1]); + } + if (args.length >= 3) { + numPartition = Integer.parseInt(args[2]); + } + + SparkConf sparkConf = new SparkConf().setAppName("JavaFPGrowthExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD> transactions = sc.textFile(inputFile).map( + new Function>() { + @Override + public ArrayList call(String s) { + return Lists.newArrayList(s.split(" ")); + } + } + ); + + FPGrowthModel model = new FPGrowth() + .setMinSupport(minSupport) + .setNumPartitions(numPartition) + .run(transactions); + + for (FPGrowth.FreqItemset s: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); + } + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java new file mode 100644 index 0000000000000..36207ae38d9a9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.DistributedLDAModel; +import org.apache.spark.mllib.clustering.LDA; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class JavaLDAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("LDA Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_lda_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + // Index documents with unique IDs + JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 doc_id) { + return doc_id.swap(); + } + } + )); + corpus.cache(); + + // Cluster the documents into three topics using LDA + DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + + // Output topics. Each is a distribution over words (matching word count vectors) + System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() + + " words):"); + Matrix topics = ldaModel.topicsMatrix(); + for (int topic = 0; topic < 3; topic++) { + System.out.print("Topic " + topic + ":"); + for (int word = 0; word < ldaModel.vocabSize(); word++) { + System.out.print(" " + topics.apply(word, topic)); + } + System.out.println(); + } + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java new file mode 100644 index 0000000000000..6c6f9768f015e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple3; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.clustering.PowerIterationClustering; +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; + +/** + * Java example for graph clustering using power iteration clustering (PIC). + */ +public class JavaPowerIterationClusteringExample { + public static void main(String[] args) { + SparkConf sparkConf = new SparkConf().setAppName("JavaPowerIterationClusteringExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + @SuppressWarnings("unchecked") + JavaRDD> similarities = sc.parallelize(Lists.newArrayList( + new Tuple3(0L, 1L, 0.9), + new Tuple3(1L, 2L, 0.9), + new Tuple3(2L, 3L, 0.9), + new Tuple3(3L, 4L, 0.1), + new Tuple3(4L, 5L, 0.9))); + + PowerIterationClustering pic = new PowerIterationClustering() + .setK(2) + .setMaxIterations(10); + PowerIterationClusteringModel model = pic.run(similarities); + + for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { + System.out.println(a.id() + " -> " + a.cluster()); + } + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 8defb769ffaaf..8159ffbe2d269 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -55,7 +55,7 @@ public void setAge(int age) { public static void main(String[] args) throws Exception { SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); - SQLContext sqlCtx = new SQLContext(ctx); + SQLContext sqlContext = new SQLContext(ctx); System.out.println("=== Data source: RDD ==="); // Load a text file and convert each line to a Java Bean. @@ -74,11 +74,11 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class); + DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -99,12 +99,12 @@ public String call(Row row) { // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. - DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); + DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers2 = - sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override public String call(Row row) { @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a DataFrame from the file(s) pointed by path - DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); + DataFrame peopleFromJsonFile = sqlContext.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -133,8 +133,8 @@ public String call(Row row) { // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); - // SQL statements can be run by using the sql methods provided by sqlCtx. - DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + // SQL statements can be run by using the sql methods provided by sqlContext. + DataFrame teenagers3 = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrame and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlContext.jsonRDD(anotherPeopleRDD.rdd()); // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/core/src/test/scala/org/apache/spark/util/FakeClock.scala b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecord.java similarity index 71% rename from core/src/test/scala/org/apache/spark/util/FakeClock.scala rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaRecord.java index 0a45917b08dd2..e63697a79f23a 100644 --- a/core/src/test/scala/org/apache/spark/util/FakeClock.scala +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecord.java @@ -15,12 +15,17 @@ * limitations under the License. */ -package org.apache.spark.util +package org.apache.spark.examples.streaming; -class FakeClock extends Clock { - private var time = 0L +/** Java Bean class to be used with the example JavaSqlNetworkWordCount. */ +public class JavaRecord implements java.io.Serializable { + private String word; - def advance(millis: Long): Unit = time += millis + public String getWord() { + return word; + } - def getTime(): Long = time + public void setWord(String word) { + this.word = word; + } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java new file mode 100644 index 0000000000000..46562ddbbcb57 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.regex.Pattern; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.api.java.StorageLevels; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.Time; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the + * network every second. + * + * Usage: JavaSqlNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example org.apache.spark.examples.streaming.JavaSqlNetworkWordCount localhost 9999` + */ + +public final class JavaSqlNetworkWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: JavaNetworkWordCount "); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + // Create the context with a 1 second batch size + SparkConf sparkConf = new SparkConf().setAppName("JavaSqlNetworkWordCount"); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); + + // Create a JavaReceiverInputDStream on target ip:port and count the + // words in input stream of \n delimited text (eg. generated by 'nc') + // Note that no duplication in storage level only for running locally. + // Replication necessary in distributed scenario for fault tolerance. + JavaReceiverInputDStream lines = ssc.socketTextStream( + args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + + // Convert RDDs of the words DStream to DataFrame and run SQL query + words.foreachRDD(new Function2, Time, Void>() { + @Override + public Void call(JavaRDD rdd, Time time) { + SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context()); + + // Convert JavaRDD[String] to JavaRDD[bean class] to DataFrame + JavaRDD rowRDD = rdd.map(new Function() { + public JavaRecord call(String word) { + JavaRecord record = new JavaRecord(); + record.setWord(word); + return record; + } + }); + DataFrame wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRecord.class); + + // Register as table + wordsDataFrame.registerTempTable("words"); + + // Do word count on table using SQL and print it + DataFrame wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word"); + System.out.println("========= " + time + "========="); + wordCountsDataFrame.show(); + return null; + } + }); + + ssc.start(); + ssc.awaitTermination(); + } +} + +/** Lazily instantiated singleton instance of SQLContext */ +class JavaSQLContextSingleton { + static private transient SQLContext instance = null; + static public SQLContext getInstance(SparkContext sparkContext) { + if (instance == null) { + instance = new SQLContext(sparkContext); + } + return instance; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java new file mode 100644 index 0000000000000..d46c7107c7a21 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; + +import scala.Tuple2; + +import com.google.common.base.Optional; +import com.google.common.collect.Lists; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.StorageLevels; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every + * second starting with initial value of word count. + * Usage: JavaStatefulNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive + * data. + *

+ * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example + * org.apache.spark.examples.streaming.JavaStatefulNetworkWordCount localhost 9999` + */ +public class JavaStatefulNetworkWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: JavaStatefulNetworkWordCount "); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + // Update the cumulative count function + final Function2, Optional, Optional> updateFunction = + new Function2, Optional, Optional>() { + @Override + public Optional call(List values, Optional state) { + Integer newSum = state.or(0); + for (Integer value : values) { + newSum += value; + } + return Optional.of(newSum); + } + }; + + // Create the context with a 1 second batch size + SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount"); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); + ssc.checkpoint("."); + + // Initial RDD input to updateStateByKey + List> tuples = Arrays.asList(new Tuple2("hello", 1), + new Tuple2("world", 1)); + JavaPairRDD initialRDD = ssc.sc().parallelizePairs(tuples); + + JavaReceiverInputDStream lines = ssc.socketTextStream( + args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2); + + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + + JavaPairDStream wordsDstream = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }); + + // This will give a Dstream made of state (which is the cumulative count of the words) + JavaPairDStream stateDstream = wordsDstream.updateStateByKey(updateFunction, + new HashPartitioner(ssc.sc().defaultParallelism()), initialRDD); + + stateDstream.print(); + ssc.start(); + ssc.awaitTermination(); + } +} diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index c7df3d7b74767..c73edb7fd6b20 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -16,10 +16,10 @@ # from pyspark import SparkContext -from pyspark.sql import SQLContext, Row from pyspark.ml import Pipeline -from pyspark.ml.feature import HashingTF, Tokenizer from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import Row, SQLContext """ @@ -33,46 +33,36 @@ if __name__ == "__main__": sc = SparkContext(appName="SimpleTextClassificationPipeline") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) # Prepare training documents, which are labeled. - LabeledDocument = Row('id', 'text', 'label') - training = sqlCtx.inferSchema( - sc.parallelize([(0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)]) - .map(lambda x: LabeledDocument(*x))) + LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)]) \ + .map(lambda x: LabeledDocument(*x)).toDF() # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. - tokenizer = Tokenizer() \ - .setInputCol("text") \ - .setOutputCol("words") - hashingTF = HashingTF() \ - .setInputCol(tokenizer.getOutputCol()) \ - .setOutputCol("features") - lr = LogisticRegression() \ - .setMaxIter(10) \ - .setRegParam(0.01) - pipeline = Pipeline() \ - .setStages([tokenizer, hashingTF, lr]) + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10, regParam=0.01) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) # Fit the pipeline to training documents. model = pipeline.fit(training) # Prepare test documents, which are unlabeled. - Document = Row('id', 'text') - test = sqlCtx.inferSchema( - sc.parallelize([(4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")]) - .map(lambda x: Document(*x))) + Document = Row("id", "text") + test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() # Make predictions on test documents and print columns of interest. prediction = model.transform(test) - prediction.registerTempTable("prediction") - selected = sqlCtx.sql("SELECT id, text, prediction from prediction") + selected = prediction.select("id", "text", "prediction") for row in selected.collect(): print row diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index b5a70db2b9a3c..fcbf56cbf0c52 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -44,19 +44,19 @@ def summarize(dataset): print >> sys.stderr, "Usage: dataset_example.py " exit(-1) sc = SparkContext(appName="DatasetExample") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) if len(sys.argv) == 2: input = sys.argv[1] else: input = "data/mllib/sample_libsvm_data.txt" points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() summarize(dataset0) tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) print "Save dataset as a Parquet file to %s." % tempdir dataset0.saveAsParquetFile(tempdir) print "Load it back and summarize it again." - dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() summarize(dataset1) shutil.rmtree(tempdir) diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index 7f5c68e3d0fe2..d89361f324917 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -19,7 +19,7 @@ from pyspark import SparkContext from pyspark.sql import SQLContext -from pyspark.sql import Row, StructField, StructType, StringType, IntegerType +from pyspark.sql.types import Row, StructField, StructType, StringType, IntegerType if __name__ == "__main__": @@ -31,7 +31,7 @@ Row(name="Smith", age=23), Row(name="Sarah", age=18)]) # Infer schema from the first row, create a DataFrame and print the schema - some_df = sqlContext.inferSchema(some_rdd) + some_df = sqlContext.createDataFrame(some_rdd) some_df.printSchema() # Another RDD is created from a list of tuples @@ -40,7 +40,7 @@ schema = StructType([StructField("person_name", StringType(), False), StructField("person_age", IntegerType(), False)]) # Create a DataFrame by applying the schema to the RDD and print the schema - another_df = sqlContext.applySchema(another_rdd, schema) + another_df = sqlContext.createDataFrame(another_rdd, schema) another_df.printSchema() # root # |-- age: integer (nullable = true) diff --git a/examples/src/main/python/status_api_demo.py b/examples/src/main/python/status_api_demo.py new file mode 100644 index 0000000000000..a33bdc475a06d --- /dev/null +++ b/examples/src/main/python/status_api_demo.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time +import threading +import Queue + +from pyspark import SparkConf, SparkContext + + +def delayed(seconds): + def f(x): + time.sleep(seconds) + return x + return f + + +def call_in_background(f, *args): + result = Queue.Queue(1) + t = threading.Thread(target=lambda: result.put(f(*args))) + t.daemon = True + t.start() + return result + + +def main(): + conf = SparkConf().set("spark.ui.showConsoleProgress", "false") + sc = SparkContext(appName="PythonStatusAPIDemo", conf=conf) + + def run(): + rdd = sc.parallelize(range(10), 10).map(delayed(2)) + reduced = rdd.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) + return reduced.map(delayed(2)).collect() + + result = call_in_background(run) + status = sc.statusTracker() + while result.empty(): + ids = status.getJobIdsForGroup() + for id in ids: + job = status.getJobInfo(id) + print "Job", id, "status: ", job.status + for sid in job.stageIds: + info = status.getStageInfo(sid) + if info: + print "Stage %d: %d tasks total (%d active, %d complete)" % \ + (sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks) + time.sleep(1) + + print "Job results are:", result.get() + sc.stop() + +if __name__ == "__main__": + main() diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index ed398a82b8bb0..51e1ff822fc55 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -17,13 +17,13 @@ """ Counts words in UTF8 encoded, '\n' delimited text received from the network every second. - Usage: network_wordcount.py + Usage: kafka_wordcount.py To run this on your local machine, you need to setup Kafka and create a producer first, see http://kafka.apache.org/documentation.html#quickstart and then run the example - `$ bin/spark-submit --driver-class-path external/kafka-assembly/target/scala-*/\ + `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ localhost:2181 test` """ diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py new file mode 100644 index 0000000000000..f89bc562d856b --- /dev/null +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the + network every second. + + Usage: sql_network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/sql_network_wordcount.py localhost 9999` +""" + +import os +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.sql import SQLContext, Row + + +def getSqlContextInstance(sparkContext): + if ('sqlContextSingletonInstance' not in globals()): + globals()['sqlContextSingletonInstance'] = SQLContext(sparkContext) + return globals()['sqlContextSingletonInstance'] + + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: sql_network_wordcount.py " + exit(-1) + host, port = sys.argv[1:] + sc = SparkContext(appName="PythonSqlNetworkWordCount") + ssc = StreamingContext(sc, 1) + + # Create a socket stream on target ip:port and count the + # words in input stream of \n delimited text (eg. generated by 'nc') + lines = ssc.socketTextStream(host, int(port)) + words = lines.flatMap(lambda line: line.split(" ")) + + # Convert RDDs of the words DStream to DataFrame and run SQL query + def process(time, rdd): + print "========= %s =========" % str(time) + + try: + # Get the singleton instance of SQLContext + sqlContext = getSqlContextInstance(rdd.context) + + # Convert RDD[String] to RDD[Row] to DataFrame + rowRdd = rdd.map(lambda w: Row(word=w)) + wordsDataFrame = sqlContext.createDataFrame(rowRdd) + + # Register as table + wordsDataFrame.registerTempTable("words") + + # Do word count on table using SQL and print it + wordCountsDataFrame = \ + sqlContext.sql("select word, count(*) as total from words group by word") + wordCountsDataFrame.show() + except: + pass + + words.foreachRDD(process) + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 1b53f3edbe92e..4c129dbe2d12d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -29,7 +29,7 @@ object BroadcastTest { val blockSize = if (args.length > 3) args(3) else "4096" val sparkConf = new SparkConf().setAppName("Broadcast Test") - .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroaddcastFactory") + .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroadcastFactory") .set("spark.broadcast.blockSize", blockSize) val sc = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 283bb80f1c788..6c0af20461d3b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -23,6 +23,7 @@ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} /** @@ -43,10 +44,10 @@ object CrossValidatorExample { val conf = new SparkConf().setAppName("CrossValidatorExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training documents, which are labeled. - val training = sparkContext.parallelize(Seq( + val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -89,21 +90,21 @@ object CrossValidatorExample { crossval.setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. - val cvModel = crossval.fit(training) + val cvModel = crossval.fit(training.toDF()) // Prepare test documents, which are unlabeled. - val test = sparkContext.parallelize(Seq( + val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test) - .select("id", "text", "score", "prediction") + cvModel.transform(test.toDF()) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala new file mode 100644 index 0000000000000..df26798e41b7b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel} +import org.apache.spark.ml.param.{Params, IntParam, ParamMap} +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + + +/** + * A simple example demonstrating how to write your own learning algorithm using Estimator, + * Transformer, and other abstractions. + * This mimics [[org.apache.spark.ml.classification.LogisticRegression]]. + * Run with + * {{{ + * bin/run-example ml.DeveloperApiExample + * }}} + */ +object DeveloperApiExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("DeveloperApiExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Prepare training data. + val training = sc.parallelize(Seq( + LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))) + + // Create a LogisticRegression instance. This instance is an Estimator. + val lr = new MyLogisticRegression() + // Print out the parameters, documentation, and any default values. + println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") + + // We may set parameters using setter methods. + lr.setMaxIter(10) + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + val model = lr.fit(training.toDF()) + + // Prepare test data. + val test = sc.parallelize(Seq( + LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) + + // Make predictions on test data. + val sumPredictions: Double = model.transform(test.toDF()) + .select("features", "label", "prediction") + .collect() + .map { case Row(features: Vector, label: Double, prediction: Double) => + prediction + }.sum + assert(sumPredictions == 0.0, + "MyLogisticRegression predicted something other than 0, even though all weights are 0!") + + sc.stop() + } +} + +/** + * Example of defining a parameter trait for a user-defined type of [[Classifier]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private trait MyLogisticRegressionParams extends ClassifierParams { + + /** + * Param for max number of iterations + * + * NOTE: The usual way to add a parameter to a model or algorithm is to include: + * - val myParamName: ParamType + * - def getMyParamName + * - def setMyParamName + * Here, we have a trait to be mixed in with the Estimator and Model (MyLogisticRegression + * and MyLogisticRegressionModel). We place the setter (setMaxIter) method in the Estimator + * class since the maxIter parameter is only used during training (not in the Model). + */ + val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + def getMaxIter: Int = get(maxIter) +} + +/** + * Example of defining a type of [[Classifier]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private class MyLogisticRegression + extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel] + with MyLogisticRegressionParams { + + setMaxIter(100) // Initialize + + // The parameter setter is in this class since it should return type MyLogisticRegression. + def setMaxIter(value: Int): this.type = set(maxIter, value) + + // This method is used by fit() + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): MyLogisticRegressionModel = { + // Extract columns from data using helper method. + val oldDataset = extractLabeledPoints(dataset, paramMap) + + // Do learning to estimate the weight vector. + val numFeatures = oldDataset.take(1)(0).features.size + val weights = Vectors.zeros(numFeatures) // Learning would happen here. + + // Create a model, and return it. + new MyLogisticRegressionModel(this, paramMap, weights) + } +} + +/** + * Example of defining a type of [[ClassificationModel]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private class MyLogisticRegressionModel( + override val parent: MyLogisticRegression, + override val fittingParamMap: ParamMap, + val weights: Vector) + extends ClassificationModel[Vector, MyLogisticRegressionModel] + with MyLogisticRegressionParams { + + // This uses the default implementation of transform(), which reads column "features" and outputs + // columns "prediction" and "rawPrediction." + + // This uses the default implementation of predict(), which chooses the label corresponding to + // the maximum value returned by [[predictRaw()]]. + + /** + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + */ + override protected def predictRaw(features: Vector): Vector = { + val margin = BLAS.dot(features, weights) + // There are 2 classes (binary classification), so we return a length-2 vector, + // where index i corresponds to class i (i = 0, 1). + Vectors.dense(-margin, margin) + } + + /** Number of classes the label can take. 2 indicates binary classification. */ + override val numClasses: Int = 2 + + /** + * Create a copy of the model. + * The copy is shallow, except for the embedded paramMap, which gets a deep copy. + * + * This is used for the defaul implementation of [[transform()]]. + */ + override protected def copy(): MyLogisticRegressionModel = { + val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights) + Params.inheritValues(this.paramMap, this, m) + m + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index b7885829459a3..25f21113bf622 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -93,8 +93,8 @@ object MovieLensALS { | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \ | examples/target/scala-*/spark-examples-*.jar \ | --rank 10 --maxIter 15 --regParam 0.1 \ - | --movies path/to/movielens/movies.dat \ - | --ratings path/to/movielens/ratings.dat + | --movies data/mllib/als/sample_movielens_movies.txt \ + | --ratings data/mllib/als/sample_movielens_ratings.txt """.stripMargin) } @@ -109,7 +109,7 @@ object MovieLensALS { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache() @@ -137,9 +137,9 @@ object MovieLensALS { .setRegParam(params.regParam) .setNumBlocks(params.numBlocks) - val model = als.fit(training) + val model = als.fit(training.toDF()) - val predictions = model.transform(test).cache() + val predictions = model.transform(test.toDF()).cache() // Evaluate the model. // TODO: Create an evaluator to compute RMSE. @@ -157,17 +157,23 @@ object MovieLensALS { println(s"Test RMSE = $rmse.") // Inspect false positives. - predictions.registerTempTable("prediction") - sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie") - sqlContext.sql( - """ - |SELECT userId, prediction.movieId, title, rating, prediction - | FROM prediction JOIN movie ON prediction.movieId = movie.movieId - | WHERE rating <= 1 AND prediction >= 4 - | LIMIT 100 - """.stripMargin) - .collect() - .foreach(println) + // Note: We reference columns in 2 ways: + // (1) predictions("movieId") lets us specify the movieId column in the predictions + // DataFrame, rather than the movieId column in the movies DataFrame. + // (2) $"userId" specifies the userId column in the predictions DataFrame. + // We could also write predictions("userId") but do not have to since + // the movies DataFrame does not have a column "userId." + val movies = sc.textFile(params.movies).map(Movie.parseMovie).toDF() + val falsePositives = predictions.join(movies) + .where((predictions("movieId") === movies("movieId")) + && ($"rating" <= 1) && ($"prediction" >= 4)) + .select($"userId", predictions("movieId"), $"title", $"rating", $"prediction") + val numFalsePositives = falsePositives.count() + println(s"Found $numFalsePositives false positives") + if (numFalsePositives > 0) { + println(s"Example false positives:") + falsePositives.limit(100).collect().foreach(println) + } sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index 95cc9801eaeb9..bf805149d0af6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -37,12 +37,12 @@ object SimpleParamsExample { val conf = new SparkConf().setAppName("SimpleParamsExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training data. - // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans - // into DataFrames, where it uses the bean metadata to infer the schema. - val training = sparkContext.parallelize(Seq( + // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes + // into DataFrames, where it uses the case class metadata to infer the schema. + val training = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), @@ -58,7 +58,7 @@ object SimpleParamsExample { .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. - val model1 = lr.fit(training) + val model1 = lr.fit(training.toDF()) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -72,29 +72,29 @@ object SimpleParamsExample { paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. // One can also combine ParamMaps. - val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name + val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. - val model2 = lr.fit(training, paramMapCombined) + val model2 = lr.fit(training.toDF(), paramMapCombined) println("Model 2 was fit using parameters: " + model2.fittingParamMap) - // Prepare test documents. - val test = sparkContext.parallelize(Seq( + // Prepare test data. + val test = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) - // Make predictions on test documents using the Transformer.transform() method. + // Make predictions on test data using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. - // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' - // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test) - .select("features", "label", "probability", "prediction") + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + model2.transform(test.toDF()) + .select("features", "label", "myProbability", "prediction") .collect() - .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => - println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println(s"($features, $label) -> prob=$prob, prediction=$prediction") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 065db62b0f5ed..6772efd2c581c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -23,6 +23,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} @BeanInfo @@ -44,10 +45,10 @@ object SimpleTextClassificationPipeline { val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training documents, which are labeled. - val training = sparkContext.parallelize(Seq( + val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -68,21 +69,21 @@ object SimpleTextClassificationPipeline { .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. - val model = pipeline.fit(training) + val model = pipeline.fit(training.toDF()) // Prepare test documents, which are unlabeled. - val test = sparkContext.parallelize(Seq( + val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. - model.transform(test) - .select("id", "text", "score", "prediction") + model.transform(test.toDF()) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index ab58375649d25..e943d6c889fab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -71,7 +71,7 @@ object DatasetExample { val conf = new SparkConf().setAppName(s"DatasetExample with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ // for implicit conversions + import sqlContext.implicits._ // for implicit conversions // Load input data val origData: RDD[LabeledPoint] = params.dataFormat match { @@ -81,18 +81,18 @@ object DatasetExample { println(s"Loaded ${origData.count()} instances from file: ${params.input}") // Convert input data to DataFrame explicitly. - val df: DataFrame = origData.toDataFrame + val df: DataFrame = origData.toDF() println(s"Inferred schema:\n${df.schema.prettyJson}") println(s"Converted to DataFrame with ${df.count()} records") - // Select columns, using implicit conversion to DataFrames. - val labelsDf: DataFrame = origData.select("label") + // Select columns + val labelsDf: DataFrame = df.select("label") val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } val numLabels = labels.count() val meanLabel = labels.fold(0.0)(_ + _) / numLabels println(s"Selected label column with average value $meanLabel") - val featuresDf: DataFrame = origData.select("features") + val featuresDf: DataFrame = df.select("features") val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 205d80dd02682..262fd2c9611d0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -272,6 +272,8 @@ object DecisionTreeRunner { case Variance => impurity.Variance } + params.checkpointDir.foreach(sc.setCheckpointDir) + val strategy = new Strategy( algo = params.algo, @@ -282,7 +284,6 @@ object DecisionTreeRunner { minInstancesPerNode = params.minInstancesPerNode, minInfoGain = params.minInfoGain, useNodeIdCache = params.useNodeIdCache, - checkpointDir = params.checkpointDir, checkpointInterval = params.checkpointInterval) if (params.numTrees == 1) { val startTime = System.nanoTime() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala new file mode 100644 index 0000000000000..13f24a1e59610 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.mllib.fpm.FPGrowth +import org.apache.spark.{SparkConf, SparkContext} + +/** + * Example for mining frequent itemsets using FP-growth. + * Example usage: ./bin/run-example mllib.FPGrowthExample \ + * --minSupport 0.8 --numPartition 2 ./data/mllib/sample_fpgrowth.txt + */ +object FPGrowthExample { + + case class Params( + input: String = null, + minSupport: Double = 0.3, + numPartition: Int = -1) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("FPGrowthExample") { + head("FPGrowth: an example FP-growth app.") + opt[Double]("minSupport") + .text(s"minimal support level, default: ${defaultParams.minSupport}") + .action((x, c) => c.copy(minSupport = x)) + opt[Int]("numPartition") + .text(s"number of partition, default: ${defaultParams.numPartition}") + .action((x, c) => c.copy(numPartition = x)) + arg[String]("") + .text("input paths to input data set, whose file format is that each line " + + "contains a transaction with each item in String and separated by a space") + .required() + .action((x, c) => c.copy(input = x)) + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"FPGrowthExample with $params") + val sc = new SparkContext(conf) + val transactions = sc.textFile(params.input).map(_.split(" ")).cache() + + println(s"Number of transactions: ${transactions.count()}") + + val model = new FPGrowth() + .setMinSupport(params.minSupport) + .setNumPartitions(params.numPartition) + .run(transactions) + + println(s"Number of frequent itemsets: ${model.freqItemsets.count()}") + + model.freqItemsets.collect().foreach { itemset => + println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index f4c545ad70e96..11399a7633638 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -134,7 +134,7 @@ object LDAExample { .setTopicConcentration(params.topicConcentration) .setCheckpointInterval(params.checkpointInterval) if (params.checkpointDir.nonEmpty) { - lda.setCheckpointDir(params.checkpointDir.get) + sc.setCheckpointDir(params.checkpointDir.get) } val startTime = System.nanoTime() val ldaModel = lda.run(corpus) @@ -159,7 +159,7 @@ object LDAExample { } println() } - + sc.stop() } /** diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala new file mode 100644 index 0000000000000..91c9772744f18 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.log4j.{Level, Logger} +import scopt.OptionParser + +import org.apache.spark.mllib.clustering.PowerIterationClustering +import org.apache.spark.rdd.RDD +import org.apache.spark.{SparkConf, SparkContext} + +/** + * An example Power Iteration Clustering http://www.icml2010.org/papers/387.pdf app. + * Takes an input of K concentric circles and the number of points in the innermost circle. + * The output should be K clusters - each cluster containing precisely the points associated + * with each of the input circles. + * + * Run with + * {{{ + * ./bin/run-example mllib.PowerIterationClusteringExample [options] + * + * Where options include: + * k: Number of circles/clusters + * n: Number of sampled points on innermost circle.. There are proportionally more points + * within the outer/larger circles + * maxIterations: Number of Power Iterations + * outerRadius: radius of the outermost of the concentric circles + * }}} + * + * Here is a sample run and output: + * + * ./bin/run-example mllib.PowerIterationClusteringExample -k 3 --n 30 --maxIterations 15 + * + * Cluster assignments: 1 -> [0,1,2,3,4],2 -> [5,6,7,8,9,10,11,12,13,14], + * 0 -> [15,16,17,18,19,20,21,22,23,24,25,26,27,28,29] + * + * + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object PowerIterationClusteringExample { + + case class Params( + input: String = null, + k: Int = 3, + numPoints: Int = 5, + maxIterations: Int = 10, + outerRadius: Double = 3.0 + ) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("PIC Circles") { + head("PowerIterationClusteringExample: an example PIC app using concentric circles.") + opt[Int]('k', "k") + .text(s"number of circles (/clusters), default: ${defaultParams.k}") + .action((x, c) => c.copy(k = x)) + opt[Int]('n', "n") + .text(s"number of points in smallest circle, default: ${defaultParams.numPoints}") + .action((x, c) => c.copy(numPoints = x)) + opt[Int]("maxIterations") + .text(s"number of iterations, default: ${defaultParams.maxIterations}") + .action((x, c) => c.copy(maxIterations = x)) + opt[Int]('r', "r") + .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}") + .action((x, c) => c.copy(numPoints = x)) + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf() + .setMaster("local") + .setAppName(s"PowerIterationClustering with $params") + val sc = new SparkContext(conf) + + Logger.getRootLogger.setLevel(Level.WARN) + + val circlesRdd = generateCirclesRdd(sc, params.k, params.numPoints, params.outerRadius) + val model = new PowerIterationClustering() + .setK(params.k) + .setMaxIterations(params.maxIterations) + .run(circlesRdd) + + val clusters = model.assignments.collect().groupBy(_.cluster).mapValues(_.map(_.id)) + val assignments = clusters.toList.sortBy { case (k, v) => v.length} + val assignmentsStr = assignments + .map { case (k, v) => + s"$k -> ${v.sorted.mkString("[", ",", "]")}" + }.mkString(",") + val sizesStr = assignments.map { + _._2.size + }.sorted.mkString("(", ",", ")") + println(s"Cluster assignments: $assignmentsStr\ncluster sizes: $sizesStr") + + sc.stop() + } + + def generateCircle(radius: Double, n: Int) = { + Seq.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (radius * math.cos(theta), radius * math.sin(theta)) + } + } + + def generateCirclesRdd(sc: SparkContext, + nCircles: Int = 3, + nPoints: Int = 30, + outerRadius: Double): RDD[(Long, Long, Double)] = { + + val radii = Array.tabulate(nCircles) { cx => outerRadius / (nCircles - cx)} + val groupSizes = Array.tabulate(nCircles) { cx => (cx + 1) * nPoints} + val points = (0 until nCircles).flatMap { cx => + generateCircle(radii(cx), groupSizes(cx)) + }.zipWithIndex + val rdd = sc.parallelize(points) + val distancesRdd = rdd.cartesian(rdd).flatMap { case (((x0, y0), i0), ((x1, y1), i1)) => + if (i0 < i1) { + Some((i0.toLong, i1.toLong, gaussianSimilarity((x0, y0), (x1, y1), 1.0))) + } else { + None + } + } + distancesRdd + } + + /** + * Gaussian Similarity: http://en.wikipedia.org/wiki/Radial_basis_function_kernel + */ + def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double) = { + val coeff = 1.0 / (math.sqrt(2.0 * math.Pi) * sigma) + val expCoeff = -1.0 / 2.0 * math.pow(sigma, 2.0) + val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2) + coeff * math.exp(expCoeff * ssquares) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 82a0b637b3cff..6331d1c0060f8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -32,33 +32,33 @@ object RDDRelation { val sqlContext = new SQLContext(sc) // Importing the SQL context gives access to all the SQL functions and implicit conversions. - import sqlContext._ + import sqlContext.implicits._ - val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) + val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF() // Any RDD containing case classes can be registered as a table. The schema of the table is // automatically inferred using scala reflection. - rdd.registerTempTable("records") + df.registerTempTable("records") // Once tables have been registered, you can run SQL queries over them. println("Result of SELECT *:") - sql("SELECT * FROM records").collect().foreach(println) + sqlContext.sql("SELECT * FROM records").collect().foreach(println) // Aggregation queries are also supported. - val count = sql("SELECT COUNT(*) FROM records").collect().head.getLong(0) + val count = sqlContext.sql("SELECT COUNT(*) FROM records").collect().head.getLong(0) println(s"COUNT(*): $count") // The results of SQL queries are themselves RDDs and support all normal RDD functions. The // items in the RDD are of type Row, which allows you to access each column by ordinal. - val rddFromSql = sql("SELECT key, value FROM records WHERE key < 10") + val rddFromSql = sqlContext.sql("SELECT key, value FROM records WHERE key < 10") println("Result of RDD.map:") rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. - rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) + df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. - rdd.saveAsParquetFile("pair.parquet") + df.saveAsParquetFile("pair.parquet") // Read in parquet file. Parquet files are self-describing so the schmema is preserved. val parquetFile = sqlContext.parquetFile("pair.parquet") @@ -68,7 +68,7 @@ object RDDRelation { // These files can also be registered as tables. parquetFile.registerTempTable("parquetFile") - sql("SELECT * FROM parquetFile").collect().foreach(println) + sqlContext.sql("SELECT * FROM parquetFile").collect().foreach(println) sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 5725da1848114..b7ba60ec28155 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -43,7 +43,8 @@ object HiveFromSpark { // HiveContext. When not configured by the hive-site.xml, the context automatically // creates metastore_db and warehouse in the current directory. val hiveContext = new HiveContext(sc) - import hiveContext._ + import hiveContext.implicits._ + import hiveContext.sql sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") sql(s"LOAD DATA LOCAL INPATH '${kv1File.getAbsolutePath}' INTO TABLE src") @@ -67,7 +68,7 @@ object HiveFromSpark { // You can also register RDDs as temporary tables within a HiveContext. val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - rdd.registerTempTable("records") + rdd.toDF().registerTempTable("records") // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala new file mode 100644 index 0000000000000..5a6b9216a3fbc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Time, Seconds, StreamingContext} +import org.apache.spark.util.IntParam +import org.apache.spark.sql.SQLContext +import org.apache.spark.storage.StorageLevel + +/** + * Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the + * network every second. + * + * Usage: SqlNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example org.apache.spark.examples.streaming.SqlNetworkWordCount localhost 9999` + */ + +object SqlNetworkWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: NetworkWordCount ") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + // Create the context with a 2 second batch size + val sparkConf = new SparkConf().setAppName("SqlNetworkWordCount") + val ssc = new StreamingContext(sparkConf, Seconds(2)) + + // Create a socket stream on target ip:port and count the + // words in input stream of \n delimited text (eg. generated by 'nc') + // Note that no duplication in storage level only for running locally. + // Replication necessary in distributed scenario for fault tolerance. + val lines = ssc.socketTextStream(args(0), args(1).toInt, StorageLevel.MEMORY_AND_DISK_SER) + val words = lines.flatMap(_.split(" ")) + + // Convert RDDs of the words DStream to DataFrame and run SQL query + words.foreachRDD((rdd: RDD[String], time: Time) => { + // Get the singleton instance of SQLContext + val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) + import sqlContext.implicits._ + + // Convert RDD[String] to RDD[case class] to DataFrame + val wordsDataFrame = rdd.map(w => Record(w)).toDF() + + // Register as table + wordsDataFrame.registerTempTable("words") + + // Do word count on table using SQL and print it + val wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word") + println(s"========= $time =========") + wordCountsDataFrame.show() + }) + + ssc.start() + ssc.awaitTermination() + } +} + + +/** Case class for converting RDD to DataFrame */ +case class Record(word: String) + + +/** Lazily instantiated singleton instance of SQLContext */ +object SQLContextSingleton { + + @transient private var instance: SQLContext = _ + + def getInstance(sparkContext: SparkContext): SQLContext = { + if (instance == null) { + instance = new SQLContext(sparkContext) + } + instance + } +} diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 0706f1ebf66e2..f46a2a091db75 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 1f2681394c583..02331e83d3f8f 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 4b732c1592ab2..44dec45c227ca 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -19,7 +19,6 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import org.apache.spark.annotation.Experimental import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} @@ -121,7 +120,6 @@ object FlumeUtils { * @param port Port of the host at which the Spark Sink is listening * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, hostname: String, @@ -138,7 +136,6 @@ object FlumeUtils { * @param addresses List of InetSocketAddresses representing the hosts to connect to. * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, addresses: Seq[InetSocketAddress], @@ -159,7 +156,6 @@ object FlumeUtils { * result in this stream using more threads * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, addresses: Seq[InetSocketAddress], @@ -178,7 +174,6 @@ object FlumeUtils { * @param hostname Hostname of the host on which the Spark Sink is running * @param port Port of the host at which the Spark Sink is listening */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, hostname: String, @@ -195,7 +190,6 @@ object FlumeUtils { * @param port Port of the host at which the Spark Sink is listening * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, hostname: String, @@ -212,7 +206,6 @@ object FlumeUtils { * @param addresses List of InetSocketAddresses on which the Spark Sink is running. * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, addresses: Array[InetSocketAddress], @@ -233,7 +226,6 @@ object FlumeUtils { * result in this stream using more threads * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, addresses: Array[InetSocketAddress], diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60c..cfedb5a042a35 100644 --- a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index b57a1c71e35b9..e04d4088df7dc 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -34,10 +34,9 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.util.ManualClock import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ManualClock, Utils} class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { @@ -54,7 +53,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging def beforeFunction() { logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") } before(beforeFunction()) @@ -236,7 +235,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging tx.commit() tx.close() Thread.sleep(500) // Allow some time for the events to reach - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } null } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index f333e3891b5f0..51d273af8da84 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -19,15 +19,16 @@ package org.apache.spark.streaming.flume import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer -import java.nio.charset.Charset import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.base.Charsets import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils import org.apache.flume.source.avro import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline @@ -40,7 +41,6 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted} import org.apache.spark.util.Utils class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { @@ -76,7 +76,8 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L /** Find a free port */ private def findFreePort(): Int = { - Utils.startServiceOnPort(23456, (trialPort: Int) => { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) @@ -108,7 +109,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L val inputEvents = input.map { item => val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8"))) + event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8))) event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) event } @@ -138,14 +139,13 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L status should be (avro.Status.OK) } - val decoder = Charset.forName("UTF-8").newDecoder() eventually(timeout(10 seconds), interval(100 milliseconds)) { val outputEvents = outputBuffer.flatten.map { _.event } outputEvents.foreach { event => event.getHeaders.get("test") should be("header") } - val output = outputEvents.map(event => decoder.decode(event.getBody()).toString) + val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) output should be (input) } } diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 503fc129dc4f2..23ca5d2a72fe6 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml @@ -33,9 +33,6 @@ streaming-kafka-assembly - scala-${scala.binary.version} - spark-streaming-kafka-assembly-${project.version}.jar - ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} @@ -61,7 +58,6 @@ maven-shade-plugin false - ${spark.jar} *:* diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index b29b0509656ba..5b8f4a358108b 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml @@ -44,7 +44,7 @@ org.apache.kafka kafka_${scala.binary.version} - 0.8.0 + 0.8.1.1 com.sun.jmx diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala new file mode 100644 index 0000000000000..5a74febb4bd46 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Represent the host and port info for a Kafka broker. + * Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID + */ +@Experimental +final class Broker private( + /** Broker's hostname */ + val host: String, + /** Broker's port */ + val port: Int) extends Serializable { + override def equals(obj: Any): Boolean = obj match { + case that: Broker => + this.host == that.host && + this.port == that.port + case _ => false + } + + override def hashCode: Int = { + 41 * (41 + host.hashCode) + port + } + + override def toString(): String = { + s"Broker($host, $port)" + } +} + +/** + * :: Experimental :: + * Companion object that provides methods to create instances of [[Broker]]. + */ +@Experimental +object Broker { + def create(host: String, port: Int): Broker = + new Broker(host, port) + + def apply(host: String, port: Int): Broker = + new Broker(host, port) + + def unapply(broker: Broker): Option[(String, Int)] = { + if (broker == null) { + None + } else { + Some((broker.host, broker.port)) + } + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala new file mode 100644 index 0000000000000..04e65cb3d708c --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + + +import scala.annotation.tailrec +import scala.collection.mutable +import scala.reflect.{classTag, ClassTag} + +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import kafka.serializer.Decoder + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset +import org.apache.spark.streaming.{StreamingContext, Time} +import org.apache.spark.streaming.dstream._ + +/** + * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where + * each given Kafka topic/partition corresponds to an RDD partition. + * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number + * of messages + * per second that each '''partition''' will accept. + * Starting offsets are specified in advance, + * and this DStream is not responsible for committing offsets, + * so that you can control exactly-once semantics. + * For an easy interface to Kafka-managed offsets, + * see {@link org.apache.spark.streaming.kafka.KafkaCluster} + * @param kafkaParams Kafka + * configuration parameters. + * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form. + * @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler function for translating each message into the desired type + */ +private[streaming] +class DirectKafkaInputDStream[ + K: ClassTag, + V: ClassTag, + U <: Decoder[K]: ClassTag, + T <: Decoder[V]: ClassTag, + R: ClassTag]( + @transient ssc_ : StreamingContext, + val kafkaParams: Map[String, String], + val fromOffsets: Map[TopicAndPartition, Long], + messageHandler: MessageAndMetadata[K, V] => R +) extends InputDStream[R](ssc_) with Logging { + val maxRetries = context.sparkContext.getConf.getInt( + "spark.streaming.kafka.maxRetries", 1) + + protected[streaming] override val checkpointData = + new DirectKafkaInputDStreamCheckpointData + + protected val kc = new KafkaCluster(kafkaParams) + + protected val maxMessagesPerPartition: Option[Long] = { + val ratePerSec = context.sparkContext.getConf.getInt( + "spark.streaming.kafka.maxRatePerPartition", 0) + if (ratePerSec > 0) { + val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 + Some((secsPerBatch * ratePerSec).toLong) + } else { + None + } + } + + protected var currentOffsets = fromOffsets + + @tailrec + protected final def latestLeaderOffsets(retries: Int): Map[TopicAndPartition, LeaderOffset] = { + val o = kc.getLatestLeaderOffsets(currentOffsets.keySet) + // Either.fold would confuse @tailrec, do it manually + if (o.isLeft) { + val err = o.left.get.toString + if (retries <= 0) { + throw new SparkException(err) + } else { + log.error(err) + Thread.sleep(kc.config.refreshLeaderBackoffMs) + latestLeaderOffsets(retries - 1) + } + } else { + o.right.get + } + } + + // limits the maximum number of messages per partition + protected def clamp( + leaderOffsets: Map[TopicAndPartition, LeaderOffset]): Map[TopicAndPartition, LeaderOffset] = { + maxMessagesPerPartition.map { mmp => + leaderOffsets.map { case (tp, lo) => + tp -> lo.copy(offset = Math.min(currentOffsets(tp) + mmp, lo.offset)) + } + }.getOrElse(leaderOffsets) + } + + override def compute(validTime: Time): Option[KafkaRDD[K, V, U, T, R]] = { + val untilOffsets = clamp(latestLeaderOffsets(maxRetries)) + val rdd = KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) + + currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) + Some(rdd) + } + + override def start(): Unit = { + } + + def stop(): Unit = { + } + + private[streaming] + class DirectKafkaInputDStreamCheckpointData extends DStreamCheckpointData(this) { + def batchForTime = data.asInstanceOf[mutable.HashMap[ + Time, Array[OffsetRange.OffsetRangeTuple]]] + + override def update(time: Time) { + batchForTime.clear() + generatedRDDs.foreach { kv => + val a = kv._2.asInstanceOf[KafkaRDD[K, V, U, T, R]].offsetRanges.map(_.toTuple).toArray + batchForTime += kv._1 -> a + } + } + + override def cleanup(time: Time) { } + + override def restore() { + // this is assuming that the topics don't change during execution, which is true currently + val topics = fromOffsets.keySet + val leaders = kc.findLeaders(topics).fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + + batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => + logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") + generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) + } + } + } + +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala new file mode 100644 index 0000000000000..2f7e0ab39fefd --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -0,0 +1,372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.util.control.NonFatal +import scala.util.Random +import scala.collection.mutable.ArrayBuffer +import java.util.Properties +import kafka.api._ +import kafka.common.{ErrorMapping, OffsetMetadataAndError, TopicAndPartition} +import kafka.consumer.{ConsumerConfig, SimpleConsumer} +import org.apache.spark.SparkException + +/** + * Convenience methods for interacting with a Kafka cluster. + * @param kafkaParams Kafka + * configuration parameters. + * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form + */ +private[spark] +class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { + import KafkaCluster.{Err, LeaderOffset, SimpleConsumerConfig} + + // ConsumerConfig isn't serializable + @transient private var _config: SimpleConsumerConfig = null + + def config: SimpleConsumerConfig = this.synchronized { + if (_config == null) { + _config = SimpleConsumerConfig(kafkaParams) + } + _config + } + + def connect(host: String, port: Int): SimpleConsumer = + new SimpleConsumer(host, port, config.socketTimeoutMs, + config.socketReceiveBufferBytes, config.clientId) + + def connectLeader(topic: String, partition: Int): Either[Err, SimpleConsumer] = + findLeader(topic, partition).right.map(hp => connect(hp._1, hp._2)) + + // Metadata api + // scalastyle:off + // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-MetadataAPI + // scalastyle:on + + def findLeader(topic: String, partition: Int): Either[Err, (String, Int)] = { + val req = TopicMetadataRequest(TopicMetadataRequest.CurrentVersion, + 0, config.clientId, Seq(topic)) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp: TopicMetadataResponse = consumer.send(req) + resp.topicsMetadata.find(_.topic == topic).flatMap { tm: TopicMetadata => + tm.partitionsMetadata.find(_.partitionId == partition) + }.foreach { pm: PartitionMetadata => + pm.leader.foreach { leader => + return Right((leader.host, leader.port)) + } + } + } + Left(errs) + } + + def findLeaders( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, (String, Int)]] = { + val topics = topicAndPartitions.map(_.topic) + val response = getPartitionMetadata(topics).right + val answer = response.flatMap { tms: Set[TopicMetadata] => + val leaderMap = tms.flatMap { tm: TopicMetadata => + tm.partitionsMetadata.flatMap { pm: PartitionMetadata => + val tp = TopicAndPartition(tm.topic, pm.partitionId) + if (topicAndPartitions(tp)) { + pm.leader.map { l => + tp -> (l.host -> l.port) + } + } else { + None + } + } + }.toMap + + if (leaderMap.keys.size == topicAndPartitions.size) { + Right(leaderMap) + } else { + val missing = topicAndPartitions.diff(leaderMap.keySet) + val err = new Err + err.append(new SparkException(s"Couldn't find leaders for ${missing}")) + Left(err) + } + } + answer + } + + def getPartitions(topics: Set[String]): Either[Err, Set[TopicAndPartition]] = { + getPartitionMetadata(topics).right.map { r => + r.flatMap { tm: TopicMetadata => + tm.partitionsMetadata.map { pm: PartitionMetadata => + TopicAndPartition(tm.topic, pm.partitionId) + } + } + } + } + + def getPartitionMetadata(topics: Set[String]): Either[Err, Set[TopicMetadata]] = { + val req = TopicMetadataRequest( + TopicMetadataRequest.CurrentVersion, 0, config.clientId, topics.toSeq) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp: TopicMetadataResponse = consumer.send(req) + // error codes here indicate missing / just created topic, + // repeating on a different broker wont be useful + return Right(resp.topicsMetadata.toSet) + } + Left(errs) + } + + // Leader offset api + // scalastyle:off + // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-OffsetAPI + // scalastyle:on + + def getLatestLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = + getLeaderOffsets(topicAndPartitions, OffsetRequest.LatestTime) + + def getEarliestLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = + getLeaderOffsets(topicAndPartitions, OffsetRequest.EarliestTime) + + def getLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition], + before: Long + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = { + getLeaderOffsets(topicAndPartitions, before, 1).right.map { r => + r.map { kv => + // mapValues isnt serializable, see SI-7005 + kv._1 -> kv._2.head + } + } + } + + private def flip[K, V](m: Map[K, V]): Map[V, Seq[K]] = + m.groupBy(_._2).map { kv => + kv._1 -> kv._2.keys.toSeq + } + + def getLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition], + before: Long, + maxNumOffsets: Int + ): Either[Err, Map[TopicAndPartition, Seq[LeaderOffset]]] = { + findLeaders(topicAndPartitions).right.flatMap { tpToLeader => + val leaderToTp: Map[(String, Int), Seq[TopicAndPartition]] = flip(tpToLeader) + val leaders = leaderToTp.keys + var result = Map[TopicAndPartition, Seq[LeaderOffset]]() + val errs = new Err + withBrokers(leaders, errs) { consumer => + val partitionsToGetOffsets: Seq[TopicAndPartition] = + leaderToTp((consumer.host, consumer.port)) + val reqMap = partitionsToGetOffsets.map { tp: TopicAndPartition => + tp -> PartitionOffsetRequestInfo(before, maxNumOffsets) + }.toMap + val req = OffsetRequest(reqMap) + val resp = consumer.getOffsetsBefore(req) + val respMap = resp.partitionErrorAndOffsets + partitionsToGetOffsets.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { por: PartitionOffsetsResponse => + if (por.error == ErrorMapping.NoError) { + if (por.offsets.nonEmpty) { + result += tp -> por.offsets.map { off => + LeaderOffset(consumer.host, consumer.port, off) + } + } else { + errs.append(new SparkException( + s"Empty offsets for ${tp}, is ${before} before log beginning?")) + } + } else { + errs.append(ErrorMapping.exceptionFor(por.error)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't find leader offsets for ${missing}")) + Left(errs) + } + } + + // Consumer offset api + // scalastyle:off + // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-OffsetCommit/FetchAPI + // scalastyle:on + + /** Requires Kafka >= 0.8.1.1 */ + def getConsumerOffsets( + groupId: String, + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, Long]] = { + getConsumerOffsetMetadata(groupId, topicAndPartitions).right.map { r => + r.map { kv => + kv._1 -> kv._2.offset + } + } + } + + /** Requires Kafka >= 0.8.1.1 */ + def getConsumerOffsetMetadata( + groupId: String, + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, OffsetMetadataAndError]] = { + var result = Map[TopicAndPartition, OffsetMetadataAndError]() + val req = OffsetFetchRequest(groupId, topicAndPartitions.toSeq) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp = consumer.fetchOffsets(req) + val respMap = resp.requestInfo + val needed = topicAndPartitions.diff(result.keySet) + needed.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { ome: OffsetMetadataAndError => + if (ome.error == ErrorMapping.NoError) { + result += tp -> ome + } else { + errs.append(ErrorMapping.exceptionFor(ome.error)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't find consumer offsets for ${missing}")) + Left(errs) + } + + /** Requires Kafka >= 0.8.1.1 */ + def setConsumerOffsets( + groupId: String, + offsets: Map[TopicAndPartition, Long] + ): Either[Err, Map[TopicAndPartition, Short]] = { + setConsumerOffsetMetadata(groupId, offsets.map { kv => + kv._1 -> OffsetMetadataAndError(kv._2) + }) + } + + /** Requires Kafka >= 0.8.1.1 */ + def setConsumerOffsetMetadata( + groupId: String, + metadata: Map[TopicAndPartition, OffsetMetadataAndError] + ): Either[Err, Map[TopicAndPartition, Short]] = { + var result = Map[TopicAndPartition, Short]() + val req = OffsetCommitRequest(groupId, metadata) + val errs = new Err + val topicAndPartitions = metadata.keySet + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp = consumer.commitOffsets(req) + val respMap = resp.requestInfo + val needed = topicAndPartitions.diff(result.keySet) + needed.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { err: Short => + if (err == ErrorMapping.NoError) { + result += tp -> err + } else { + errs.append(ErrorMapping.exceptionFor(err)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't set offsets for ${missing}")) + Left(errs) + } + + // Try a call against potentially multiple brokers, accumulating errors + private def withBrokers(brokers: Iterable[(String, Int)], errs: Err) + (fn: SimpleConsumer => Any): Unit = { + brokers.foreach { hp => + var consumer: SimpleConsumer = null + try { + consumer = connect(hp._1, hp._2) + fn(consumer) + } catch { + case NonFatal(e) => + errs.append(e) + } finally { + if (consumer != null) { + consumer.close() + } + } + } + } +} + +private[spark] +object KafkaCluster { + type Err = ArrayBuffer[Throwable] + + private[spark] + case class LeaderOffset(host: String, port: Int, offset: Long) + + /** + * High-level kafka consumers connect to ZK. ConsumerConfig assumes this use case. + * Simple consumers connect directly to brokers, but need many of the same configs. + * This subclass won't warn about missing ZK params, or presence of broker params. + */ + private[spark] + class SimpleConsumerConfig private(brokers: String, originalProps: Properties) + extends ConsumerConfig(originalProps) { + val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp => + val hpa = hp.split(":") + if (hpa.size == 1) { + throw new SparkException(s"Broker not the in correct format of : [$brokers]") + } + (hpa(0), hpa(1).toInt) + } + } + + private[spark] + object SimpleConsumerConfig { + /** + * Make a consumer config without requiring group.id or zookeeper.connect, + * since communicating with brokers also needs common settings such as timeout + */ + def apply(kafkaParams: Map[String, String]): SimpleConsumerConfig = { + // These keys are from other pre-existing kafka configs for specifying brokers, accept either + val brokers = kafkaParams.get("metadata.broker.list") + .orElse(kafkaParams.get("bootstrap.servers")) + .getOrElse(throw new SparkException( + "Must specify metadata.broker.list or bootstrap.servers")) + + val props = new Properties() + kafkaParams.foreach { case (key, value) => + // prevent warnings on parameters ConsumerConfig doesn't know about + if (key != "metadata.broker.list" && key != "bootstrap.servers") { + props.put(key, value) + } + } + + Seq("zookeeper.connect", "group.id").foreach { s => + if (!props.contains(s)) { + props.setProperty(s, "") + } + } + + new SimpleConsumerConfig(brokers, props) + } + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala new file mode 100644 index 0000000000000..d56cc01be9514 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.reflect.{classTag, ClassTag} + +import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.NextIterator + +import java.util.Properties +import kafka.api.{FetchRequestBuilder, FetchResponse} +import kafka.common.{ErrorMapping, TopicAndPartition} +import kafka.consumer.{ConsumerConfig, SimpleConsumer} +import kafka.message.{MessageAndMetadata, MessageAndOffset} +import kafka.serializer.Decoder +import kafka.utils.VerifiableProperties + +/** + * A batch-oriented interface for consuming from Kafka. + * Starting and ending offsets are specified in advance, + * so that you can control exactly-once semantics. + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD + * @param messageHandler function for translating each message into the desired type + */ +private[kafka] +class KafkaRDD[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag, + R: ClassTag] private[spark] ( + sc: SparkContext, + kafkaParams: Map[String, String], + val offsetRanges: Array[OffsetRange], + leaders: Map[TopicAndPartition, (String, Int)], + messageHandler: MessageAndMetadata[K, V] => R + ) extends RDD[R](sc, Nil) with Logging with HasOffsetRanges { + override def getPartitions: Array[Partition] = { + offsetRanges.zipWithIndex.map { case (o, i) => + val (host, port) = leaders(TopicAndPartition(o.topic, o.partition)) + new KafkaRDDPartition(i, o.topic, o.partition, o.fromOffset, o.untilOffset, host, port) + }.toArray + } + + override def getPreferredLocations(thePart: Partition): Seq[String] = { + val part = thePart.asInstanceOf[KafkaRDDPartition] + // TODO is additional hostname resolution necessary here + Seq(part.host) + } + + private def errBeginAfterEnd(part: KafkaRDDPartition): String = + s"Beginning offset ${part.fromOffset} is after the ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition}. " + + "You either provided an invalid fromOffset, or the Kafka topic has been damaged" + + private def errRanOutBeforeEnd(part: KafkaRDDPartition): String = + s"Ran out of messages before reaching ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition} start ${part.fromOffset}." + + " This should not happen, and indicates that messages may have been lost" + + private def errOvershotEnd(itemOffset: Long, part: KafkaRDDPartition): String = + s"Got ${itemOffset} > ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition} start ${part.fromOffset}." + + " This should not happen, and indicates a message may have been skipped" + + override def compute(thePart: Partition, context: TaskContext): Iterator[R] = { + val part = thePart.asInstanceOf[KafkaRDDPartition] + assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) + if (part.fromOffset == part.untilOffset) { + log.warn("Beginning offset ${part.fromOffset} is the same as ending offset " + + s"skipping ${part.topic} ${part.partition}") + Iterator.empty + } else { + new KafkaRDDIterator(part, context) + } + } + + private class KafkaRDDIterator( + part: KafkaRDDPartition, + context: TaskContext) extends NextIterator[R] { + + context.addTaskCompletionListener{ context => closeIfNeeded() } + + log.info(s"Computing topic ${part.topic}, partition ${part.partition} " + + s"offsets ${part.fromOffset} -> ${part.untilOffset}") + + val kc = new KafkaCluster(kafkaParams) + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(kc.config.props) + .asInstanceOf[Decoder[K]] + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(kc.config.props) + .asInstanceOf[Decoder[V]] + val consumer = connectLeader + var requestOffset = part.fromOffset + var iter: Iterator[MessageAndOffset] = null + + // The idea is to use the provided preferred host, except on task retry atttempts, + // to minimize number of kafka metadata requests + private def connectLeader: SimpleConsumer = { + if (context.attemptNumber > 0) { + kc.connectLeader(part.topic, part.partition).fold( + errs => throw new SparkException( + s"Couldn't connect to leader for topic ${part.topic} ${part.partition}: " + + errs.mkString("\n")), + consumer => consumer + ) + } else { + kc.connect(part.host, part.port) + } + } + + private def handleFetchErr(resp: FetchResponse) { + if (resp.hasError) { + val err = resp.errorCode(part.topic, part.partition) + if (err == ErrorMapping.LeaderNotAvailableCode || + err == ErrorMapping.NotLeaderForPartitionCode) { + log.error(s"Lost leader for topic ${part.topic} partition ${part.partition}, " + + s" sleeping for ${kc.config.refreshLeaderBackoffMs}ms") + Thread.sleep(kc.config.refreshLeaderBackoffMs) + } + // Let normal rdd retry sort out reconnect attempts + throw ErrorMapping.exceptionFor(err) + } + } + + private def fetchBatch: Iterator[MessageAndOffset] = { + val req = new FetchRequestBuilder() + .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes) + .build() + val resp = consumer.fetch(req) + handleFetchErr(resp) + // kafka may return a batch that starts before the requested offset + resp.messageSet(part.topic, part.partition) + .iterator + .dropWhile(_.offset < requestOffset) + } + + override def close() = consumer.close() + + override def getNext(): R = { + if (iter == null || !iter.hasNext) { + iter = fetchBatch + } + if (!iter.hasNext) { + assert(requestOffset == part.untilOffset, errRanOutBeforeEnd(part)) + finished = true + null.asInstanceOf[R] + } else { + val item = iter.next() + if (item.offset >= part.untilOffset) { + assert(item.offset == part.untilOffset, errOvershotEnd(item.offset, part)) + finished = true + null.asInstanceOf[R] + } else { + requestOffset = item.nextOffset + messageHandler(new MessageAndMetadata( + part.topic, part.partition, item.message, item.offset, keyDecoder, valueDecoder)) + } + } + } + } +} + +private[kafka] +object KafkaRDD { + import KafkaCluster.LeaderOffset + + /** + * @param kafkaParams Kafka + * configuration parameters. + * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form. + * @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the batch + * @param untilOffsets per-topic/partition Kafka offsets defining the (exclusive) + * ending point of the batch + * @param messageHandler function for translating each message into the desired type + */ + def apply[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag, + R: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + fromOffsets: Map[TopicAndPartition, Long], + untilOffsets: Map[TopicAndPartition, LeaderOffset], + messageHandler: MessageAndMetadata[K, V] => R + ): KafkaRDD[K, V, U, T, R] = { + val leaders = untilOffsets.map { case (tp, lo) => + tp -> (lo.host, lo.port) + }.toMap + + val offsetRanges = fromOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + }.toArray + + new KafkaRDD[K, V, U, T, R](sc, kafkaParams, offsetRanges, leaders, messageHandler) + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala new file mode 100644 index 0000000000000..a842a6f17766f --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import org.apache.spark.Partition + +/** @param topic kafka topic name + * @param partition kafka partition id + * @param fromOffset inclusive starting offset + * @param untilOffset exclusive ending offset + * @param host preferred kafka host, i.e. the leader at the time the rdd was created + * @param port preferred kafka host's port + */ +private[kafka] +class KafkaRDDPartition( + val index: Int, + val topic: String, + val partition: Int, + val fromOffset: Long, + val untilOffset: Long, + val host: String, + val port: Int +) extends Partition diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index df725f0c65a64..5a9bd4214cf51 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -18,21 +18,30 @@ package org.apache.spark.streaming.kafka import java.lang.{Integer => JInt} +import java.lang.{Long => JLong} import java.util.{Map => JMap} +import java.util.{Set => JSet} import scala.reflect.ClassTag import scala.collection.JavaConversions._ -import kafka.serializer.{Decoder, StringDecoder} +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder} +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.api.java.{JavaPairInputDStream, JavaInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} +import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} object KafkaUtils { /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param ssc StreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..) * @param groupId The group id for this consumer @@ -56,7 +65,7 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param ssc StreamingContext object * @param kafkaParams Map of kafka configuration parameters, * see http://kafka.apache.org/08/configuration.html @@ -75,7 +84,7 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. * @param jssc JavaStreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..) @@ -93,7 +102,7 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param jssc JavaStreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. @@ -113,10 +122,10 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param jssc JavaStreamingContext object - * @param keyTypeClass Key type of RDD - * @param valueTypeClass value type of RDD + * @param keyTypeClass Key type of DStream + * @param valueTypeClass value type of Dstream * @param keyDecoderClass Type of kafka key decoder * @param valueDecoderClass Type of kafka value decoder * @param kafkaParams Map of kafka configuration parameters, @@ -144,4 +153,409 @@ object KafkaUtils { createStream[K, V, U, T]( jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) } + + /** get leaders for the given offset ranges, or throw an exception */ + private def leadersForRanges( + kafkaParams: Map[String, String], + offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = { + val kc = new KafkaCluster(kafkaParams) + val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet + val leaders = kc.findLeaders(topics).fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + leaders + } + + /** + * Create a RDD from Kafka using offset ranges for each topic and partition. + * + * @param sc SparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + */ + @Experimental + def createRDD[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + offsetRanges: Array[OffsetRange] + ): RDD[(K, V)] = { + val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) + val leaders = leadersForRanges(kafkaParams, offsetRanges) + new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler) + } + + /** + * :: Experimental :: + * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * specify the Kafka leader to connect to (to optimize fetching) and access the message as well + * as the metadata. + * + * @param sc SparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, + * in which case leaders will be looked up on the driver. + * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createRDD[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag, + R: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + offsetRanges: Array[OffsetRange], + leaders: Map[TopicAndPartition, Broker], + messageHandler: MessageAndMetadata[K, V] => R + ): RDD[R] = { + val leaderMap = if (leaders.isEmpty) { + leadersForRanges(kafkaParams, offsetRanges) + } else { + // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker + leaders.map { + case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port)) + }.toMap + } + new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) + } + + + /** + * Create a RDD from Kafka using offset ranges for each topic and partition. + * + * @param jsc JavaSparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + */ + @Experimental + def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]]( + jsc: JavaSparkContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + kafkaParams: JMap[String, String], + offsetRanges: Array[OffsetRange] + ): JavaPairRDD[K, V] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + new JavaPairRDD(createRDD[K, V, KD, VD]( + jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges)) + } + + /** + * :: Experimental :: + * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * specify the Kafka leader to connect to (to optimize fetching) and access the message as well + * as the metadata. + * + * @param jsc JavaSparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, + * in which case leaders will be looked up on the driver. + * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( + jsc: JavaSparkContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + recordClass: Class[R], + kafkaParams: JMap[String, String], + offsetRanges: Array[OffsetRange], + leaders: JMap[TopicAndPartition, Broker], + messageHandler: JFunction[MessageAndMetadata[K, V], R] + ): JavaRDD[R] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) + val leaderMap = Map(leaders.toSeq: _*) + createRDD[K, V, KD, VD, R]( + jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges, leaderMap, messageHandler.call _) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param ssc StreamingContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createDirectStream[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag, + R: ClassTag] ( + ssc: StreamingContext, + kafkaParams: Map[String, String], + fromOffsets: Map[TopicAndPartition, Long], + messageHandler: MessageAndMetadata[K, V] => R + ): InputDStream[R] = { + new DirectKafkaInputDStream[K, V, KD, VD, R]( + ssc, kafkaParams, fromOffsets, messageHandler) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param ssc StreamingContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. + * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" + * to determine where the stream starts (defaults to "largest") + * @param topics Names of the topics to consume + */ + @Experimental + def createDirectStream[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag] ( + ssc: StreamingContext, + kafkaParams: Map[String, String], + topics: Set[String] + ): InputDStream[(K, V)] = { + val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) + val kc = new KafkaCluster(kafkaParams) + val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) + + (for { + topicPartitions <- kc.getPartitions(topics).right + leaderOffsets <- (if (reset == Some("smallest")) { + kc.getEarliestLeaderOffsets(topicPartitions) + } else { + kc.getLatestLeaderOffsets(topicPartitions) + }).right + } yield { + val fromOffsets = leaderOffsets.map { case (tp, lo) => + (tp, lo.offset) + } + new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( + ssc, kafkaParams, fromOffsets, messageHandler) + }).fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param jssc JavaStreamingContext object + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param keyDecoderClass Class of the key decoder + * @param valueDecoderClass Class of the value decoder + * @param recordClass Class of the records in DStream + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. + * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( + jssc: JavaStreamingContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + recordClass: Class[R], + kafkaParams: JMap[String, String], + fromOffsets: JMap[TopicAndPartition, JLong], + messageHandler: JFunction[MessageAndMetadata[K, V], R] + ): JavaInputDStream[R] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) + createDirectStream[K, V, KD, VD, R]( + jssc.ssc, + Map(kafkaParams.toSeq: _*), + Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*), + messageHandler.call _ + ) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param jssc JavaStreamingContext object + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param keyDecoderClass Class of the key decoder + * @param valueDecoderClass Class type of the value decoder + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. + * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" + * to determine where the stream starts (defaults to "largest") + * @param topics Names of the topics to consume + */ + @Experimental + def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]]( + jssc: JavaStreamingContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + kafkaParams: JMap[String, String], + topics: JSet[String] + ): JavaPairInputDStream[K, V] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + createDirectStream[K, V, KD, VD]( + jssc.ssc, + Map(kafkaParams.toSeq: _*), + Set(topics.toSeq: _*) + ) + } +} + +/** + * This is a helper class that wraps the KafkaUtils.createStream() into more + * Python-friendly class and function so that it can be easily + * instantiated and called from Python's KafkaUtils (see SPARK-6027). + * + * The zero-arg constructor helps instantiate this class from the Class object + * classOf[KafkaUtilsPythonHelper].newInstance(), and the createStream() + * takes care of known parameters instead of passing them from Python + */ +private class KafkaUtilsPythonHelper { + def createStream( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JMap[String, JInt], + storageLevel: StorageLevel): JavaPairReceiverInputDStream[Array[Byte], Array[Byte]] = { + KafkaUtils.createStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + kafkaParams, + topics, + storageLevel) + } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala new file mode 100644 index 0000000000000..9c3dfeb8f5928 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import kafka.common.TopicAndPartition + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Represents any object that has a collection of [[OffsetRange]]s. This can be used access the + * offset ranges in RDDs generated by the direct Kafka DStream (see + * [[KafkaUtils.createDirectStream()]]). + * {{{ + * KafkaUtils.createDirectStream(...).foreachRDD { rdd => + * val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + * ... + * } + * }}} + */ +@Experimental +trait HasOffsetRanges { + def offsetRanges: Array[OffsetRange] +} + +/** + * :: Experimental :: + * Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class + * can be created with `OffsetRange.create()`. + */ +@Experimental +final class OffsetRange private( + /** Kafka topic name */ + val topic: String, + /** Kafka partition id */ + val partition: Int, + /** inclusive starting offset */ + val fromOffset: Long, + /** exclusive ending offset */ + val untilOffset: Long) extends Serializable { + import OffsetRange.OffsetRangeTuple + + override def equals(obj: Any): Boolean = obj match { + case that: OffsetRange => + this.topic == that.topic && + this.partition == that.partition && + this.fromOffset == that.fromOffset && + this.untilOffset == that.untilOffset + case _ => false + } + + override def hashCode(): Int = { + toTuple.hashCode() + } + + override def toString(): String = { + s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset]" + } + + /** this is to avoid ClassNotFoundException during checkpoint restore */ + private[streaming] + def toTuple: OffsetRangeTuple = (topic, partition, fromOffset, untilOffset) +} + +/** + * :: Experimental :: + * Companion object the provides methods to create instances of [[OffsetRange]]. + */ +@Experimental +object OffsetRange { + def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = + new OffsetRange(topic, partition, fromOffset, untilOffset) + + def create( + topicAndPartition: TopicAndPartition, + fromOffset: Long, + untilOffset: Long): OffsetRange = + new OffsetRange(topicAndPartition.topic, topicAndPartition.partition, fromOffset, untilOffset) + + def apply(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = + new OffsetRange(topic, partition, fromOffset, untilOffset) + + def apply( + topicAndPartition: TopicAndPartition, + fromOffset: Long, + untilOffset: Long): OffsetRange = + new OffsetRange(topicAndPartition.topic, topicAndPartition.partition, fromOffset, untilOffset) + + /** this is to avoid ClassNotFoundException during checkpoint restore */ + private[kafka] + type OffsetRangeTuple = (String, Int, Long, Long) + + private[kafka] + def apply(t: OffsetRangeTuple) = + new OffsetRange(t._1, t._2, t._3, t._4) +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index be734b80272d1..c4a44c1822c39 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -201,12 +201,31 @@ class ReliableKafkaReceiver[ topicPartitionOffsetMap.clear() } - /** Store the ready-to-be-stored block and commit the related offsets to zookeeper. */ + /** + * Store the ready-to-be-stored block and commit the related offsets to zookeeper. This method + * will try a fixed number of times to push the block. If the push fails, the receiver is stopped. + */ private def storeBlockAndCommitOffset( blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { - store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) - Option(blockOffsetMap.get(blockId)).foreach(commitOffset) - blockOffsetMap.remove(blockId) + var count = 0 + var pushed = false + var exception: Exception = null + while (!pushed && count <= 3) { + try { + store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) + pushed = true + } catch { + case ex: Exception => + count += 1 + exception = ex + } + } + if (pushed) { + Option(blockOffsetMap.get(blockId)).foreach(commitOffset) + blockOffsetMap.remove(blockId) + } else { + stop("Error while storing block into Spark", exception) + } } /** diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java new file mode 100644 index 0000000000000..1334cc8fd1b57 --- /dev/null +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Random; +import java.util.Arrays; + +import org.apache.spark.SparkConf; + +import scala.Tuple2; + +import junit.framework.Assert; + +import kafka.common.TopicAndPartition; +import kafka.message.MessageAndMetadata; +import kafka.serializer.StringDecoder; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +import org.junit.Test; +import org.junit.After; +import org.junit.Before; + +public class JavaDirectKafkaStreamSuite implements Serializable { + private transient JavaStreamingContext ssc = null; + private transient Random random = new Random(); + private transient KafkaStreamSuiteBase suiteBase = null; + + @Before + public void setUp() { + suiteBase = new KafkaStreamSuiteBase() { }; + suiteBase.setupKafka(); + System.clearProperty("spark.driver.port"); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200)); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + System.clearProperty("spark.driver.port"); + suiteBase.tearDownKafka(); + } + + @Test + public void testKafkaStream() throws InterruptedException { + String topic1 = "topic1"; + String topic2 = "topic2"; + + String[] topic1data = createTopicAndSendData(topic1); + String[] topic2data = createTopicAndSendData(topic2); + + HashSet sent = new HashSet(); + sent.addAll(Arrays.asList(topic1data)); + sent.addAll(Arrays.asList(topic2data)); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress()); + kafkaParams.put("auto.offset.reset", "smallest"); + + JavaDStream stream1 = KafkaUtils.createDirectStream( + ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topicToSet(topic1) + ).map( + new Function, String>() { + @Override + public String call(scala.Tuple2 kv) throws Exception { + return kv._2(); + } + } + ); + + JavaDStream stream2 = KafkaUtils.createDirectStream( + ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + topicOffsetToMap(topic2, (long) 0), + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return msgAndMd.message(); + } + } + ); + JavaDStream unifiedStream = stream1.union(stream2); + + final HashSet result = new HashSet(); + unifiedStream.foreachRDD( + new Function, Void>() { + @Override + public Void call(org.apache.spark.api.java.JavaRDD rdd) throws Exception { + result.addAll(rdd.collect()); + return null; + } + } + ); + ssc.start(); + long startTime = System.currentTimeMillis(); + boolean matches = false; + while (!matches && System.currentTimeMillis() - startTime < 20000) { + matches = sent.size() == result.size(); + Thread.sleep(50); + } + Assert.assertEquals(sent, result); + ssc.stop(); + } + + private HashSet topicToSet(String topic) { + HashSet topicSet = new HashSet(); + topicSet.add(topic); + return topicSet; + } + + private HashMap topicOffsetToMap(String topic, Long offsetToStart) { + HashMap topicMap = new HashMap(); + topicMap.put(new TopicAndPartition(topic, 0), offsetToStart); + return topicMap; + } + + private String[] createTopicAndSendData(String topic) { + String[] data = { topic + "-1", topic + "-2", topic + "-3"}; + suiteBase.createTopic(topic); + suiteBase.sendMessages(topic, data); + return data; + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java new file mode 100644 index 0000000000000..9d2e1705c6c73 --- /dev/null +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Arrays; + +import org.apache.spark.SparkConf; + +import scala.Tuple2; + +import junit.framework.Assert; + +import kafka.common.TopicAndPartition; +import kafka.message.MessageAndMetadata; +import kafka.serializer.StringDecoder; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +import org.junit.Test; +import org.junit.After; +import org.junit.Before; + +public class JavaKafkaRDDSuite implements Serializable { + private transient JavaSparkContext sc = null; + private transient KafkaStreamSuiteBase suiteBase = null; + + @Before + public void setUp() { + suiteBase = new KafkaStreamSuiteBase() { }; + suiteBase.setupKafka(); + System.clearProperty("spark.driver.port"); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + sc = new JavaSparkContext(sparkConf); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + suiteBase.tearDownKafka(); + } + + @Test + public void testKafkaRDD() throws InterruptedException { + String topic1 = "topic1"; + String topic2 = "topic2"; + + String[] topic1data = createTopicAndSendData(topic1); + String[] topic2data = createTopicAndSendData(topic2); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress()); + + OffsetRange[] offsetRanges = { + OffsetRange.create(topic1, 0, 0, 1), + OffsetRange.create(topic2, 0, 0, 1) + }; + + HashMap emptyLeaders = new HashMap(); + HashMap leaders = new HashMap(); + String[] hostAndPort = suiteBase.brokerAddress().split(":"); + Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); + leaders.put(new TopicAndPartition(topic1, 0), broker); + leaders.put(new TopicAndPartition(topic2, 0), broker); + + JavaRDD rdd1 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + offsetRanges + ).map( + new Function, String>() { + @Override + public String call(scala.Tuple2 kv) throws Exception { + return kv._2(); + } + } + ); + + JavaRDD rdd2 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + offsetRanges, + emptyLeaders, + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return msgAndMd.message(); + } + } + ); + + JavaRDD rdd3 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + offsetRanges, + leaders, + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return msgAndMd.message(); + } + } + ); + + // just making sure the java user apis work; the scala tests handle logic corner cases + long count1 = rdd1.count(); + long count2 = rdd2.count(); + long count3 = rdd3.count(); + Assert.assertTrue(count1 > 0); + Assert.assertEquals(count1, count2); + Assert.assertEquals(count1, count3); + } + + private String[] createTopicAndSendData(String topic) { + String[] data = { topic + "-1", topic + "-2", topic + "-3"}; + suiteBase.createTopic(topic); + suiteBase.sendMessages(topic, data); + return data; + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 6e1abf3f385ee..208cc51b29876 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -79,9 +79,10 @@ public void testKafkaStream() throws InterruptedException { suiteBase.createTopic(topic); HashMap tmp = new HashMap(sent); - suiteBase.produceAndSendMessage(topic, + suiteBase.sendMessages(topic, JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( - Predef.>conforms())); + Predef.>conforms()) + ); HashMap kafkaParams = new HashMap(); kafkaParams.put("zookeeper.connect", suiteBase.zkAddress()); diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala new file mode 100644 index 0000000000000..17ca9d145d665 --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.io.File + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.postfixOps + +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import kafka.serializer.StringDecoder +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.Utils + +class DirectKafkaStreamSuite extends KafkaStreamSuiteBase + with BeforeAndAfter with BeforeAndAfterAll with Eventually { + val sparkConf = new SparkConf() + .setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + + var sc: SparkContext = _ + var ssc: StreamingContext = _ + var testDir: File = _ + + override def beforeAll { + setupKafka() + } + + override def afterAll { + tearDownKafka() + } + + after { + if (ssc != null) { + ssc.stop() + sc = null + } + if (sc != null) { + sc.stop() + } + if (testDir != null) { + Utils.deleteRecursively(testDir) + } + } + + + test("basic stream receiving with multiple topics and smallest starting offset") { + val topics = Set("basic1", "basic2", "basic3") + val data = Map("a" -> 7, "b" -> 9) + topics.foreach { t => + createTopic(t) + sendMessages(t, data) + } + val totalSent = data.values.sum * topics.size + val kafkaParams = Map( + "metadata.broker.list" -> s"$brokerAddress", + "auto.offset.reset" -> "smallest" + ) + + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topics) + } + + val allReceived = new ArrayBuffer[(String, String)] + + stream.foreachRDD { rdd => + // Get the offset ranges in the RDD + val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + val collected = rdd.mapPartitionsWithIndex { (i, iter) => + // For each partition, get size of the range in the partition, + // and the number of items in the partition + val off = offsets(i) + val all = iter.toSeq + val partSize = all.size + val rangeSize = off.untilOffset - off.fromOffset + Iterator((partSize, rangeSize)) + }.collect + + // Verify whether number of elements in each partition + // matches with the corresponding offset range + collected.foreach { case (partSize, rangeSize) => + assert(partSize === rangeSize, "offset ranges are wrong") + } + } + stream.foreachRDD { rdd => allReceived ++= rdd.collect() } + ssc.start() + eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { + assert(allReceived.size === totalSent, + "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n")) + } + ssc.stop() + } + + test("receiving from largest starting offset") { + val topic = "largest" + val topicPartition = TopicAndPartition(topic, 0) + val data = Map("a" -> 10) + createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> s"$brokerAddress", + "auto.offset.reset" -> "largest" + ) + val kc = new KafkaCluster(kafkaParams) + def getLatestOffset(): Long = { + kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset + } + + // Send some initial messages before starting context + sendMessages(topic, data) + eventually(timeout(10 seconds), interval(20 milliseconds)) { + assert(getLatestOffset() > 3) + } + val offsetBeforeStart = getLatestOffset() + + // Setup context and kafka stream with largest offset + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Set(topic)) + } + assert( + stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]] + .fromOffsets(topicPartition) >= offsetBeforeStart, + "Start offset not from latest" + ) + + val collectedData = new mutable.ArrayBuffer[String]() + stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } + ssc.start() + val newData = Map("b" -> 10) + sendMessages(topic, newData) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + collectedData.contains("b") + } + assert(!collectedData.contains("a")) + } + + + test("creating stream by offset") { + val topic = "offset" + val topicPartition = TopicAndPartition(topic, 0) + val data = Map("a" -> 10) + createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> s"$brokerAddress", + "auto.offset.reset" -> "largest" + ) + val kc = new KafkaCluster(kafkaParams) + def getLatestOffset(): Long = { + kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset + } + + // Send some initial messages before starting context + sendMessages(topic, data) + eventually(timeout(10 seconds), interval(20 milliseconds)) { + assert(getLatestOffset() >= 10) + } + val offsetBeforeStart = getLatestOffset() + + // Setup context and kafka stream with largest offset + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder, String]( + ssc, kafkaParams, Map(topicPartition -> 11L), + (m: MessageAndMetadata[String, String]) => m.message()) + } + assert( + stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]] + .fromOffsets(topicPartition) >= offsetBeforeStart, + "Start offset not from latest" + ) + + val collectedData = new mutable.ArrayBuffer[String]() + stream.foreachRDD { rdd => collectedData ++= rdd.collect() } + ssc.start() + val newData = Map("b" -> 10) + sendMessages(topic, newData) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + collectedData.contains("b") + } + assert(!collectedData.contains("a")) + } + + // Test to verify the offset ranges can be recovered from the checkpoints + test("offset recovery") { + val topic = "recovery" + createTopic(topic) + testDir = Utils.createTempDir() + + val kafkaParams = Map( + "metadata.broker.list" -> s"$brokerAddress", + "auto.offset.reset" -> "smallest" + ) + + // Send data to Kafka and wait for it to be received + def sendDataAndWaitForReceive(data: Seq[Int]) { + val strings = data.map { _.toString} + sendMessages(topic, strings.map { _ -> 1}.toMap) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains }) + } + } + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(100)) + val kafkaStream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Set(topic)) + } + val keyedStream = kafkaStream.map { v => "key" -> v._2.toInt } + val stateStream = keyedStream.updateStateByKey { (values: Seq[Int], state: Option[Int]) => + Some(values.sum + state.getOrElse(0)) + } + ssc.checkpoint(testDir.getAbsolutePath) + + // This is to collect the raw data received from Kafka + kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => + val data = rdd.map { _._2 }.collect() + DirectKafkaStreamSuite.collectedData.appendAll(data) + } + + // This is ensure all the data is eventually receiving only once + stateStream.foreachRDD { (rdd: RDD[(String, Int)]) => + rdd.collect().headOption.foreach { x => DirectKafkaStreamSuite.total = x._2 } + } + ssc.start() + + // Send some data and wait for them to be received + for (i <- (1 to 10).grouped(4)) { + sendDataAndWaitForReceive(i) + } + + // Verify that offset ranges were generated + val offsetRangesBeforeStop = getOffsetRanges(kafkaStream) + assert(offsetRangesBeforeStop.size >= 1, "No offset ranges generated") + assert( + offsetRangesBeforeStop.head._2.forall { _.fromOffset === 0 }, + "starting offset not zero" + ) + ssc.stop() + logInfo("====== RESTARTING ========") + + // Recover context from checkpoints + ssc = new StreamingContext(testDir.getAbsolutePath) + val recoveredStream = ssc.graph.getInputStreams().head.asInstanceOf[DStream[(String, String)]] + + // Verify offset ranges have been recovered + val recoveredOffsetRanges = getOffsetRanges(recoveredStream) + assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered") + val earlierOffsetRangesAsSets = offsetRangesBeforeStop.map { x => (x._1, x._2.toSet) } + assert( + recoveredOffsetRanges.forall { or => + earlierOffsetRangesAsSets.contains((or._1, or._2.toSet)) + }, + "Recovered ranges are not the same as the ones generated" + ) + + // Restart context, give more data and verify the total at the end + // If the total is write that means each records has been received only once + ssc.start() + sendDataAndWaitForReceive(11 to 20) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(DirectKafkaStreamSuite.total === (1 to 20).sum) + } + ssc.stop() + } + + /** Get the generated offset ranges from the DirectKafkaStream */ + private def getOffsetRanges[K, V]( + kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { + kafkaStream.generatedRDDs.mapValues { rdd => + rdd.asInstanceOf[KafkaRDD[K, V, _, _, (K, V)]].offsetRanges + }.toSeq.sortBy { _._1 } + } +} + +object DirectKafkaStreamSuite { + val collectedData = new mutable.ArrayBuffer[String]() + var total = -1L +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala new file mode 100644 index 0000000000000..fc9275b7207be --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.util.Random + +import kafka.common.TopicAndPartition +import org.scalatest.BeforeAndAfterAll + +class KafkaClusterSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { + val topic = "kcsuitetopic" + Random.nextInt(10000) + val topicAndPartition = TopicAndPartition(topic, 0) + var kc: KafkaCluster = null + + override def beforeAll() { + setupKafka() + createTopic(topic) + sendMessages(topic, Map("a" -> 1)) + kc = new KafkaCluster(Map("metadata.broker.list" -> s"$brokerAddress")) + } + + override def afterAll() { + tearDownKafka() + } + + test("metadata apis") { + val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition) + val leaderAddress = s"${leader._1}:${leader._2}" + assert(leaderAddress === brokerAddress, "didn't get leader") + + val parts = kc.getPartitions(Set(topic)).right.get + assert(parts(topicAndPartition), "didn't get partitions") + } + + test("leader offset apis") { + val earliest = kc.getEarliestLeaderOffsets(Set(topicAndPartition)).right.get + assert(earliest(topicAndPartition).offset === 0, "didn't get earliest") + + val latest = kc.getLatestLeaderOffsets(Set(topicAndPartition)).right.get + assert(latest(topicAndPartition).offset === 1, "didn't get latest") + } + + test("consumer offset apis") { + val group = "kcsuitegroup" + Random.nextInt(10000) + + val offset = Random.nextInt(10000) + + val set = kc.setConsumerOffsets(group, Map(topicAndPartition -> offset)) + assert(set.isRight, "didn't set consumer offsets") + + val get = kc.getConsumerOffsets(group, Set(topicAndPartition)).right.get + assert(get(topicAndPartition) === offset, "didn't get consumer offsets") + } +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala new file mode 100644 index 0000000000000..a223da70b043f --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.util.Random + +import kafka.serializer.StringDecoder +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { + val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) + var sc: SparkContext = _ + override def beforeAll { + sc = new SparkContext(sparkConf) + + setupKafka() + } + + override def afterAll { + if (sc != null) { + sc.stop + sc = null + } + tearDownKafka() + } + + test("basic usage") { + val topic = "topicbasic" + createTopic(topic) + val messages = Set("the", "quick", "brown", "fox") + sendMessages(topic, messages.toArray) + + + val kafkaParams = Map("metadata.broker.list" -> brokerAddress, + "group.id" -> s"test-consumer-${Random.nextInt(10000)}") + + val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) + + val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, offsetRanges) + + val received = rdd.map(_._2).collect.toSet + assert(received === messages) + } + + test("iterator boundary conditions") { + // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd + val topic = "topic1" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + createTopic(topic) + + val kafkaParams = Map("metadata.broker.list" -> brokerAddress, + "group.id" -> s"test-consumer-${Random.nextInt(10000)}") + + val kc = new KafkaCluster(kafkaParams) + + // this is the "lots of messages" case + sendMessages(topic, sent) + // rdd defined from leaders after sending messages, should get the number sent + val rdd = getRdd(kc, Set(topic)) + + assert(rdd.isDefined) + assert(rdd.get.count === sent.values.sum, "didn't get all sent messages") + + val ranges = rdd.get.asInstanceOf[HasOffsetRanges] + .offsetRanges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap + + kc.setConsumerOffsets(kafkaParams("group.id"), ranges) + + // this is the "0 messages" case + val rdd2 = getRdd(kc, Set(topic)) + // shouldn't get anything, since message is sent after rdd was defined + val sentOnlyOne = Map("d" -> 1) + + sendMessages(topic, sentOnlyOne) + assert(rdd2.isDefined) + assert(rdd2.get.count === 0, "got messages when there shouldn't be any") + + // this is the "exactly 1 message" case, namely the single message from sentOnlyOne above + val rdd3 = getRdd(kc, Set(topic)) + // send lots of messages after rdd was defined, they shouldn't show up + sendMessages(topic, Map("extra" -> 22)) + + assert(rdd3.isDefined) + assert(rdd3.get.count === sentOnlyOne.values.sum, "didn't get exactly one message") + + } + + // get an rdd from the committed consumer offsets until the latest leader offsets, + private def getRdd(kc: KafkaCluster, topics: Set[String]) = { + val groupId = kc.kafkaParams("group.id") + def consumerOffsets(topicPartitions: Set[TopicAndPartition]) = { + kc.getConsumerOffsets(groupId, topicPartitions).right.toOption.orElse( + kc.getEarliestLeaderOffsets(topicPartitions).right.toOption.map { offs => + offs.map(kv => kv._1 -> kv._2.offset) + } + ) + } + kc.getPartitions(topics).right.toOption.flatMap { topicPartitions => + consumerOffsets(topicPartitions).flatMap { from => + kc.getLatestLeaderOffsets(topicPartitions).right.toOption.map { until => + val offsetRanges = from.map { case (tp: TopicAndPartition, fromOffset: Long) => + OffsetRange(tp.topic, tp.partition, fromOffset, until(tp).offset) + }.toArray + + val leaders = until.map { case (tp: TopicAndPartition, lo: KafkaCluster.LeaderOffset) => + tp -> Broker(lo.host, lo.port) + }.toMap + + KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder, String]( + sc, kc.kafkaParams, offsetRanges, leaders, + (mmd: MessageAndMetadata[String, String]) => s"${mmd.offset} ${mmd.message}") + } + } + } + } +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index b19c053ebfc44..e4966eebb9b34 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -26,7 +26,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import kafka.admin.CreateTopicCommand +import kafka.admin.AdminUtils import kafka.common.{KafkaException, TopicAndPartition} import kafka.producer.{KeyedMessage, Producer, ProducerConfig} import kafka.serializer.{StringDecoder, StringEncoder} @@ -48,30 +48,41 @@ import org.apache.spark.util.Utils */ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging { - var zkAddress: String = _ - var zkClient: ZkClient = _ - private val zkHost = "localhost" + private var zkPort: Int = 0 private val zkConnectionTimeout = 6000 private val zkSessionTimeout = 6000 private var zookeeper: EmbeddedZookeeper = _ - private var zkPort: Int = 0 + private val brokerHost = "localhost" private var brokerPort = 9092 private var brokerConf: KafkaConfig = _ private var server: KafkaServer = _ private var producer: Producer[String, String] = _ + private var zkReady = false + private var brokerReady = false + + protected var zkClient: ZkClient = _ + + def zkAddress: String = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") + s"$zkHost:$zkPort" + } + + def brokerAddress: String = { + assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") + s"$brokerHost:$brokerPort" + } def setupKafka() { // Zookeeper server startup zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") // Get the actual zookeeper binding port zkPort = zookeeper.actualPort - zkAddress = s"$zkHost:$zkPort" - logInfo("==================== 0 ====================") + zkReady = true + logInfo("==================== Zookeeper Started ====================") - zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, - ZKStringSerializer) - logInfo("==================== 1 ====================") + zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) + logInfo("==================== Zookeeper Client Created ====================") // Kafka broker startup var bindSuccess: Boolean = false @@ -80,9 +91,8 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin val brokerProps = getBrokerConfig() brokerConf = new KafkaConfig(brokerProps) server = new KafkaServer(brokerConf) - logInfo("==================== 2 ====================") server.startup() - logInfo("==================== 3 ====================") + logInfo("==================== Kafka Broker Started ====================") bindSuccess = true } catch { case e: KafkaException => @@ -94,10 +104,13 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin } Thread.sleep(2000) - logInfo("==================== 4 ====================") + logInfo("==================== Kafka + Zookeeper Ready ====================") + brokerReady = true } def tearDownKafka() { + brokerReady = false + zkReady = false if (producer != null) { producer.close() producer = null @@ -121,26 +134,23 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin } } - private def createTestMessage(topic: String, sent: Map[String, Int]) - : Seq[KeyedMessage[String, String]] = { - val messages = for ((s, freq) <- sent; i <- 0 until freq) yield { - new KeyedMessage[String, String](topic, s) - } - messages.toSeq - } - def createTopic(topic: String) { - CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") - logInfo("==================== 5 ====================") + AdminUtils.createTopic(zkClient, topic, 1, 1) // wait until metadata is propagated waitUntilMetadataIsPropagated(topic, 0) + logInfo(s"==================== Topic $topic Created ====================") } - def produceAndSendMessage(topic: String, sent: Map[String, Int]) { + def sendMessages(topic: String, messageToFreq: Map[String, Int]) { + val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray + sendMessages(topic, messages) + } + + def sendMessages(topic: String, messages: Array[String]) { producer = new Producer[String, String](new ProducerConfig(getProducerConfig())) - producer.send(createTestMessage(topic, sent): _*) + producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*) producer.close() - logInfo("==================== 6 ====================") + logInfo(s"==================== Sent Messages: ${messages.mkString(", ")} ====================") } private def getBrokerConfig(): Properties = { @@ -164,9 +174,9 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin } private def waitUntilMetadataIsPropagated(topic: String, partition: Int) { - eventually(timeout(1000 milliseconds), interval(100 milliseconds)) { + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert( - server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, partition)), + server.apis.metadataCache.containsTopicAndPartition(topic, partition), s"Partition [$topic, $partition] metadata not propagated after timeout" ) } @@ -218,7 +228,7 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) createTopic(topic) - produceAndSendMessage(topic, sent) + sendMessages(topic, sent) val kafkaParams = Map("zookeeper.connect" -> zkAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}", diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 64ccc92c81fa9..fc53c23abda85 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -79,7 +79,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter test("Reliable Kafka input stream with single topic") { var topic = "test-topic" createTopic(topic) - produceAndSendMessage(topic, data) + sendMessages(topic, data) // Verify whether the offset of this group/topic/partition is 0 before starting. assert(getCommitOffset(groupId, topic, 0) === None) @@ -111,7 +111,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1) topics.foreach { case (t, _) => createTopic(t) - produceAndSendMessage(t, data) + sendMessages(t, data) } // Before started, verify all the group/topic/partition offsets are 0. diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 560c8b9d18276..51e49ed334794 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60c..cfedb5a042a35 100644 --- a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index e84adc088a680..05b7147ade9d3 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -25,6 +25,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.activemq.broker.{TransportConnector, BrokerService} +import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence @@ -93,6 +94,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { private def setupMQTT() { broker = new BrokerService() + broker.setDataDirectoryFile(Utils.createTempDir()) connector = new TransportConnector() connector.setName("mqtt") connector.setUri(new URI("mqtt:" + brokerUri)) @@ -112,7 +114,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { } private def findFreePort(): Int = { - Utils.startServiceOnPort(23456, (trialPort: Int) => { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index da6ffe7662f63..297e2927eb508 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60c..cfedb5a042a35 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index e919c2c9b19ea..090860be52456 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60c..cfedb5a042a35 100644 --- a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 0fb431808bacd..757b19a89e504 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c815eda52bda7..3ac273d9d5842 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml @@ -67,11 +67,6 @@ scalacheck_${scala.binary.version} test - - org.easymock - easymockclassextension - test - com.novocode junit-interface diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala index 0b80b611cdce7..588e86a1887ec 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -18,9 +18,7 @@ package org.apache.spark.streaming.kinesis import org.apache.spark.Logging import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.util.Clock -import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.streaming.util.SystemClock +import org.apache.spark.util.{Clock, ManualClock, SystemClock} /** * This is a helper class for managing checkpoint clocks. @@ -35,7 +33,7 @@ private[kinesis] class KinesisCheckpointState( /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ val checkpointClock = new ManualClock() - checkpointClock.setTime(currentClock.currentTime() + checkpointInterval.milliseconds) + checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) /** * Check if it's time to checkpoint based on the current time and the derived time @@ -44,13 +42,13 @@ private[kinesis] class KinesisCheckpointState( * @return true if it's time to checkpoint */ def shouldCheckpoint(): Boolean = { - new SystemClock().currentTime() > checkpointClock.currentTime() + new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis() } /** * Advance the checkpoint clock by the checkpoint interval. */ def advanceCheckpoint() = { - checkpointClock.addToTime(checkpointInterval.milliseconds) + checkpointClock.advance(checkpointInterval.milliseconds) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 8ecc2d90160b1..af8cd875b4541 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -104,7 +104,7 @@ private[kinesis] class KinesisRecordProcessor( logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + s" records for shardId $shardId") logDebug(s"Checkpoint: Next checkpoint is at " + - s" ${checkpointState.checkpointClock.currentTime()} for shardId $shardId") + s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId") } } catch { case e: Throwable => { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 41dbd64c2b1fa..255fe65819608 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -20,17 +20,17 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions.seqAsJavaList -import org.apache.spark.annotation.Experimental import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Milliseconds import org.apache.spark.streaming.Seconds import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.TestSuiteBase -import org.apache.spark.streaming.util.Clock -import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.util.{ManualClock, Clock} + +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.scalatest.mock.EasyMockSugar +import org.scalatest.mock.MockitoSugar import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException @@ -42,10 +42,10 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record /** - * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor + * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor */ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter - with EasyMockSugar { + with MockitoSugar { val app = "TestKinesisReceiver" val stream = "mySparkStream" @@ -73,6 +73,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft currentClockMock = mock[Clock] } + override def afterFunction(): Unit = { + super.afterFunction() + // Since this suite was originally written using EasyMock, add this to preserve the old + // mocking semantics (see SPARK-5735 for more details) + verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, + checkpointStateMock, currentClockMock) + } + test("kinesis utils api") { val ssc = new StreamingContext(master, framework, batchDuration) // Tests the API, does not actually test data receiving @@ -83,193 +91,175 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft } test("process records including store and checkpoint") { - val expectedCheckpointIntervalMillis = 10 - expecting { - receiverMock.isStopped().andReturn(false).once() - receiverMock.store(record1.getData().array()).once() - receiverMock.store(record2.getData().array()).once() - checkpointStateMock.shouldCheckpoint().andReturn(true).once() - checkpointerMock.checkpoint().once() - checkpointStateMock.advanceCheckpoint().once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(false) + when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).store(record1.getData().array()) + verify(receiverMock, times(1)).store(record2.getData().array()) + verify(checkpointStateMock, times(1)).shouldCheckpoint() + verify(checkpointerMock, times(1)).checkpoint() + verify(checkpointStateMock, times(1)).advanceCheckpoint() } test("shouldn't store and checkpoint when receiver is stopped") { - expecting { - receiverMock.isStopped().andReturn(true).once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(true) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() } test("shouldn't checkpoint when exception occurs during store") { - expecting { - receiverMock.isStopped().andReturn(false).once() - receiverMock.store(record1.getData().array()).andThrow(new RuntimeException()).once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - intercept[RuntimeException] { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(false) + when(receiverMock.store(record1.getData().array())).thenThrow(new RuntimeException()) + + intercept[RuntimeException] { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) } + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).store(record1.getData().array()) } test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { + when(currentClockMock.getTimeMillis()).thenReturn(0) + val checkpointIntervalMillis = 10 - val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) - } + val checkpointState = + new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) + + verify(currentClockMock, times(1)).getTimeMillis() } test("should checkpoint if we have exceeded the checkpoint interval") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) - assert(checkpointState.shouldCheckpoint()) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) + assert(checkpointState.shouldCheckpoint()) + + verify(currentClockMock, times(1)).getTimeMillis() } test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) - assert(!checkpointState.shouldCheckpoint()) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) + assert(!checkpointState.shouldCheckpoint()) + + verify(currentClockMock, times(1)).getTimeMillis() } test("should add to time when advancing checkpoint") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointIntervalMillis = 10 - val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) - checkpointState.advanceCheckpoint() - assert(checkpointState.checkpointClock.currentTime() == (2 * checkpointIntervalMillis)) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointIntervalMillis = 10 + val checkpointState = + new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) + checkpointState.advanceCheckpoint() + assert(checkpointState.checkpointClock.getTimeMillis() == (2 * checkpointIntervalMillis)) + + verify(currentClockMock, times(1)).getTimeMillis() } test("shutdown should checkpoint if the reason is TERMINATE") { - expecting { - checkpointerMock.checkpoint().once() - } - whenExecuting(checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - val reason = ShutdownReason.TERMINATE - recordProcessor.shutdown(checkpointerMock, reason) - } + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val reason = ShutdownReason.TERMINATE + recordProcessor.shutdown(checkpointerMock, reason) + + verify(checkpointerMock, times(1)).checkpoint() } test("shutdown should not checkpoint if the reason is something other than TERMINATE") { - expecting { - } - whenExecuting(checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) - recordProcessor.shutdown(checkpointerMock, null) - } + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) + recordProcessor.shutdown(checkpointerMock, null) + + verify(checkpointerMock, never()).checkpoint() } test("retry success on first attempt") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()).thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(1)).isStopped() } test("retry success on second attempt after a Kinesis throttling exception") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andThrow(new ThrottlingException("error message")) - .andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()) + .thenThrow(new ThrottlingException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() } test("retry success on second attempt after a Kinesis dependency exception") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andThrow(new KinesisClientLibDependencyException("error message")) - .andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()) + .thenThrow(new KinesisClientLibDependencyException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() } test("retry failed after a shutdown exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new ShutdownException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[ShutdownException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new ShutdownException("error message")) + + intercept[ShutdownException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after an invalid state exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new InvalidStateException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[InvalidStateException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new InvalidStateException("error message")) + + intercept[InvalidStateException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after unexpected exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new RuntimeException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new RuntimeException("error message")) + + intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after exhausing all retries") { val expectedErrorMessage = "final try error message" - expecting { - checkpointerMock.checkpoint().andThrow(new ThrottlingException("error message")) - .andThrow(new ThrottlingException(expectedErrorMessage)).once() - } - whenExecuting(checkpointerMock) { - val exception = intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } - exception.getMessage().shouldBe(expectedErrorMessage) + when(checkpointerMock.checkpoint()) + .thenThrow(new ThrottlingException("error message")) + .thenThrow(new ThrottlingException(expectedErrorMessage)) + + val exception = intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + exception.getMessage().shouldBe(expectedErrorMessage) + + verify(checkpointerMock, times(2)).checkpoint() } } diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index f2f0aa78b0a4b..efaf57e98c731 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 8fac24b6ed86d..0a9fc906e39a3 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index 4933aecba1286..21187be7678a6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -77,7 +77,7 @@ object GraphLoader extends Logging { if (!line.isEmpty && line(0) != '#') { val lineArray = line.split("\\s+") if (lineArray.length < 2) { - logWarning("Invalid line: " + line) + throw new IllegalArgumentException("Invalid line: " + line) } val srcId = lineArray(0).toLong val dstId = lineArray(1).toLong diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 6dad167fa7411..904be213147dc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -104,8 +104,14 @@ class VertexRDDImpl[VD] private[graphx] ( this.mapVertexPartitions(_.map(f)) override def diff(other: VertexRDD[VD]): VertexRDD[VD] = { + val otherPartition = other match { + case other: VertexRDD[_] if this.partitioner == other.partitioner => + other.partitionsRDD + case _ => + VertexRDD(other.partitionBy(this.partitioner.get)).partitionsRDD + } val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true + otherPartition, preservesPartitioning = true ) { (thisIter, otherIter) => val thisPart = thisIter.next() val otherPart = otherIter.next() @@ -133,7 +139,7 @@ class VertexRDDImpl[VD] private[graphx] ( // Test if the other vertex is a VertexRDD to choose the optimal join strategy. // If the other set is a VertexRDD then we use the much more efficient leftZipJoin other match { - case other: VertexRDD[_] => + case other: VertexRDD[_] if this.partitioner == other.partitioner => leftZipJoin(other)(f) case _ => this.withPartitionsRDD[VD3]( @@ -162,7 +168,7 @@ class VertexRDDImpl[VD] private[graphx] ( // Test if the other vertex is a VertexRDD to choose the optimal join strategy. // If the other set is a VertexRDD then we use the much more efficient innerZipJoin other match { - case other: VertexRDD[_] => + case other: VertexRDD[_] if this.partitioner == other.partitioner => innerZipJoin(other)(f) case _ => this.withPartitionsRDD( diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index f58587e10a820..fc84cfbe64184 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -17,6 +17,8 @@ package org.apache.spark.graphx.lib +import org.apache.spark.annotation.Experimental + import scala.util.Random import org.jblas.DoubleMatrix import org.apache.spark.rdd._ @@ -38,6 +40,8 @@ object SVDPlusPlus { extends Serializable /** + * :: Experimental :: + * * Implement SVD++ based on "Factorization Meets the Neighborhood: * a Multifaceted Collaborative Filtering Model", * available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]]. @@ -45,12 +49,33 @@ object SVDPlusPlus { * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)), * see the details on page 6. * + * This method temporarily replaces `run()`, and replaces `DoubleMatrix` in `run()`'s return + * value with `Array[Double]`. In 1.4.0, this method will be deprecated, but will be copied + * to replace `run()`, which will then be undeprecated. + * * @param edges edges for constructing the graph * * @param conf SVDPlusPlus parameters * * @return a graph with vertex attributes containing the trained model */ + @Experimental + def runSVDPlusPlus(edges: RDD[Edge[Double]], conf: Conf) + : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) = + { + val (graph, u) = run(edges, conf) + // Convert DoubleMatrix to Array[Double]: + val newVertices = graph.vertices.mapValues(v => (v._1.toArray, v._2.toArray, v._3, v._4)) + (Graph(newVertices, graph.edges), u) + } + + /** + * This method is deprecated in favor of `runSVDPlusPlus()`, which replaces `DoubleMatrix` + * with `Array[Double]` in its return value. This method is deprecated. It will effectively + * be removed in 1.4.0 when `runSVDPlusPlus()` is copied to replace `run()`, and hence the + * return type of this method changes. + */ + @deprecated("Call runSVDPlusPlus", "1.3.0") def run(edges: RDD[Edge[Double]], conf: Conf) : (Graph[(DoubleMatrix, DoubleMatrix, Double, Double), Double], Double) = { @@ -72,17 +97,22 @@ object SVDPlusPlus { // construct graph var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache() + materialize(g) + edges.unpersist() // Calculate initial bias and norm val t0 = g.aggregateMessages[(Long, Double)]( ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) }, (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2)) - g = g.outerJoinVertices(t0) { + val gJoinT0 = g.outerJoinVertices(t0) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[(Long, Double)]) => (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) - } + }.cache() + materialize(gJoinT0) + g.unpersist() + g = gJoinT0 def sendMsgTrainF(conf: Conf, u: Double) (ctx: EdgeContext[ @@ -114,12 +144,15 @@ object SVDPlusPlus { val t1 = g.aggregateMessages[DoubleMatrix]( ctx => ctx.sendToSrc(ctx.dstAttr._2), (g1, g2) => g1.addColumnVector(g2)) - g = g.outerJoinVertices(t1) { + val gJoinT1 = g.outerJoinVertices(t1) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[DoubleMatrix]) => if (msg.isDefined) (vd._1, vd._1 .addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd - } + }.cache() + materialize(gJoinT1) + g.unpersist() + g = gJoinT1 // Phase 2, update p for user nodes and q, y for item nodes g.cache() @@ -127,13 +160,16 @@ object SVDPlusPlus { sendMsgTrainF(conf, u), (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) - g = g.outerJoinVertices(t2) { + val gJoinT2 = g.outerJoinVertices(t2) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) => (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2), vd._3 + msg.get._3, vd._4) - } + }.cache() + materialize(gJoinT2) + g.unpersist() + g = gJoinT2 } // calculate error on training set @@ -147,13 +183,26 @@ object SVDPlusPlus { val err = (ctx.attr - pred) * (ctx.attr - pred) ctx.sendToDst(err) } + g.cache() val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _) - g = g.outerJoinVertices(t3) { + val gJoinT3 = g.outerJoinVertices(t3) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd - } + }.cache() + materialize(gJoinT3) + g.unpersist() + g = gJoinT3 (g, u) } + + /** + * Forces materialization of a Graph by count()ing its RDDs. + */ + private def materialize(g: Graph[_,_]): Unit = { + g.vertices.count() + g.edges.count() + } + } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index 590f0474957dd..179f2843818e0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -61,8 +61,8 @@ object ShortestPaths { } def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = { - val newAttr = incrementMap(edge.srcAttr) - if (edge.dstAttr != addMaps(newAttr, edge.dstAttr)) Iterator((edge.dstId, newAttr)) + val newAttr = incrementMap(edge.dstAttr) + if (edge.srcAttr != addMaps(newAttr, edge.srcAttr)) Iterator((edge.srcId, newAttr)) else Iterator.empty } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index e01df56e94de9..9987a4b1a3c25 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -32,7 +32,7 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) } val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations - var (graph, u) = SVDPlusPlus.run(edges, conf) + var (graph, u) = SVDPlusPlus.runSVDPlusPlus(edges, conf) graph.cache() val err = graph.vertices.collect().map{ case (vid, vd) => if (vid % 2 == 1) vd._4 else 0.0 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala index 265827b3341c2..f2c38e79c452c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -40,7 +40,7 @@ class ShortestPathsSuite extends FunSuite with LocalSparkContext { val graph = Graph.fromEdgeTuples(edges, 1) val landmarks = Seq(1, 4).map(_.toLong) val results = ShortestPaths.run(graph, landmarks).vertices.collect.map { - case (v, spMap) => (v, spMap.mapValues(_.get)) + case (v, spMap) => (v, spMap.mapValues(i => i)) } assert(results.toSet === shortestPaths) } diff --git a/make-distribution.sh b/make-distribution.sh index 051c87c0894ae..dd990d4b96e46 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -98,9 +98,9 @@ done if [ -z "$JAVA_HOME" ]; then # Fall back on JAVA_HOME from rpm, if found if [ $(command -v rpm) ]; then - RPM_JAVA_HOME=$(rpm -E %java_home 2>/dev/null) + RPM_JAVA_HOME="$(rpm -E %java_home 2>/dev/null)" if [ "$RPM_JAVA_HOME" != "%java_home" ]; then - JAVA_HOME=$RPM_JAVA_HOME + JAVA_HOME="$RPM_JAVA_HOME" echo "No JAVA_HOME set, proceeding with '$JAVA_HOME' learned from rpm" fi fi @@ -113,24 +113,24 @@ fi if [ $(command -v git) ]; then GITREV=$(git rev-parse --short HEAD 2>/dev/null || :) - if [ ! -z $GITREV ]; then + if [ ! -z "$GITREV" ]; then GITREVSTRING=" (git revision $GITREV)" fi unset GITREV fi -if [ ! $(command -v $MVN) ] ; then +if [ ! $(command -v "$MVN") ] ; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; fi -VERSION=$($MVN help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) -SPARK_HADOOP_VERSION=$($MVN help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ +VERSION=$("$MVN" help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) +SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) -SPARK_HIVE=$($MVN help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ +SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ @@ -147,7 +147,7 @@ if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then echo "Output from 'java -version' was:" echo "$JAVA_VERSION" read -p "Would you like to continue anyways? [y,n]: " -r - if [[ ! $REPLY =~ ^[Yy]$ ]]; then + if [[ ! "$REPLY" =~ ^[Yy]$ ]]; then echo "Okay, exiting." exit 1 fi @@ -232,7 +232,7 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" if [ "$SPARK_TACHYON" == "true" ]; then TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` - pushd $TMPD > /dev/null + pushd "$TMPD" > /dev/null echo "Fetching tachyon tgz" TACHYON_DL="${TACHYON_TGZ}.part" @@ -259,7 +259,7 @@ if [ "$SPARK_TACHYON" == "true" ]; then fi popd > /dev/null - rm -rf $TMPD + rm -rf "$TMPD" fi if [ "$MAKE_TGZ" == "true" ]; then diff --git a/mllib/pom.xml b/mllib/pom.xml index a8cee3d51a780..047d1287ee5e3 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../pom.xml @@ -63,7 +63,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.10 + 0.11.2 diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index bc3defe968afd..eff7ef925dfbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -34,7 +34,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Fits a single model to the input data with optional parameters. * * @param dataset input dataset - * @param paramPairs optional list of param pairs (overwrite embedded params) + * @param paramPairs Optional list of param pairs. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ @varargs @@ -47,7 +48,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Fits a single model to the input data with provided parameter map. * * @param dataset input dataset - * @param paramMap parameter map + * @param paramMap Parameter map. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ def fit(dataset: DataFrame, paramMap: ParamMap): M @@ -58,7 +60,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Subclasses could overwrite this to optimize multi-model training. * * @param dataset input dataset - * @param paramMaps an array of parameter maps + * @param paramMaps An array of parameter maps. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted models, matching the input parameter maps */ def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index bb291e6e1fd7d..c4a36103303a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml import scala.collection.mutable.ListBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType abstract class PipelineStage extends Serializable with Logging { /** + * :: DeveloperApi :: + * * Derives the output schema from the input schema and parameters. + * The schema describes the columns and types of the data. + * + * @param schema Input schema to this stage + * @param paramMap Parameters passed to this stage + * @return Output schema from this stage */ - private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType + @DeveloperApi + def transformSchema(schema: StructType, paramMap: ParamMap): StructType /** * Derives the output schema from the input schema and parameters, optionally with logging. @@ -114,7 +122,9 @@ class Pipeline extends Estimator[PipelineModel] { throw new IllegalArgumentException( s"Do not support stage $stage of type ${stage.getClass}") } - curDataset = transformer.transform(curDataset, paramMap) + if (index < indexOfLastEstimator) { + curDataset = transformer.transform(curDataset, paramMap) + } transformers += transformer } else { transformers += stage.asInstanceOf[Transformer] @@ -124,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] { new PipelineModel(this, map, transformers.toArray) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = this.paramMap ++ paramMap val theStages = map(stages) require(theStages.toSet.size == theStages.size, @@ -169,7 +179,7 @@ class PipelineModel private[ml] ( stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap val map = (fittingParamMap ++ this.paramMap) ++ paramMap stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index cd95c16aa768d..9a5848684b179 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** @@ -62,7 +62,10 @@ abstract class Transformer extends PipelineStage with Params { private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] extends Transformer with HasInputCol with HasOutputCol with Logging { + /** @group setParam */ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] + + /** @group setParam */ def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] /** @@ -97,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) val map = this.paramMap ++ paramMap - dataset.select($"*", callUDF( - this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol))) + dataset.withColumn(map(outputCol), + callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol)))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala new file mode 100644 index 0000000000000..c5fc89f935432 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + + +/** + * :: DeveloperApi :: + * Params for classification. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] trait ClassifierParams extends PredictorParams + with HasRawPredictionCol { + + override protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) + val map = this.paramMap ++ paramMap + addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT) + } +} + +/** + * :: AlphaComponent :: + * Single-label binary or multiclass classification. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam E Concrete Estimator type + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class Classifier[ + FeaturesType, + E <: Classifier[FeaturesType, E, M], + M <: ClassificationModel[FeaturesType, M]] + extends Predictor[FeaturesType, E, M] + with ClassifierParams { + + /** @group setParam */ + def setRawPredictionCol(value: String): E = + set(rawPredictionCol, value).asInstanceOf[E] + + // TODO: defaultEvaluator (follow-up PR) +} + +/** + * :: AlphaComponent :: + * Model produced by a [[Classifier]]. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] +abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] with ClassifierParams { + + /** @group setParam */ + def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M] + + /** Number of classes (values which the label can take). */ + def numClasses: Int + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type [[Double]] + * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. + + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + val (numColsOutput, outputData) = + ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map) + if (numColsOutput == 0) { + logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData + } + + /** + * :: DeveloperApi :: + * + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + * + * This default implementation for classification predicts the index of the maximum value + * from [[predictRaw()]]. + */ + @DeveloperApi + override protected def predict(features: FeaturesType): Double = { + predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2 + } + + /** + * :: DeveloperApi :: + * + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + */ + @DeveloperApi + protected def predictRaw(features: FeaturesType): Vector + +} + +private[ml] object ClassificationModel { + + /** + * Added prediction column(s). This is separated from [[ClassificationModel.transform()]] + * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]]. + * @param dataset Input dataset + * @param map Parameter map. This will NOT be merged with the embedded paramMap; the merge + * should already be done. + * @return (number of columns added, transformed dataset) + */ + def transformColumnsImpl[FeaturesType]( + dataset: DataFrame, + model: ClassificationModel[FeaturesType, _], + map: ParamMap): (Int, DataFrame) = { + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. + var tmpData = dataset + var numColsOutput = 0 + if (map(model.rawPredictionCol) != "") { + // output raw prediction + val features2raw: FeaturesType => Vector = model.predictRaw + tmpData = tmpData.withColumn(map(model.rawPredictionCol), + callUDF(features2raw, new VectorUDT, col(map(model.featuresCol)))) + numColsOutput += 1 + if (map(model.predictionCol) != "") { + val raw2pred: Vector => Double = (rawPred) => { + rawPred.toArray.zipWithIndex.maxBy(_._1)._2 + } + tmpData = tmpData.withColumn(map(model.predictionCol), + callUDF(raw2pred, DoubleType, col(map(model.rawPredictionCol)))) + numColsOutput += 1 + } + } else if (map(model.predictionCol) != "") { + // output prediction + val features2pred: FeaturesType => Double = model.predict + tmpData = tmpData.withColumn(map(model.predictionCol), + callUDF(features2pred, DoubleType, col(map(model.featuresCol)))) + numColsOutput += 1 + } + (numColsOutput, tmpData) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 18be35ad59452..21f61d80dd95a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -18,128 +18,185 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql._ -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel + /** - * :: AlphaComponent :: * Params for logistic regression. */ -@AlphaComponent -private[classification] trait LogisticRegressionParams extends Params - with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol - with HasScoreCol with HasPredictionCol { +private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams + with HasRegParam with HasMaxIter with HasThreshold - /** - * Validates and transforms the input schema with the provided param map. - * @param schema input schema - * @param paramMap additional parameters - * @param fitting whether this is in fitting - * @return output schema - */ - protected def validateAndTransformSchema( - schema: StructType, - paramMap: ParamMap, - fitting: Boolean): StructType = { - val map = this.paramMap ++ paramMap - val featuresType = schema(map(featuresCol)).dataType - // TODO: Support casting Array[Double] and Array[Float] to Vector. - require(featuresType.isInstanceOf[VectorUDT], - s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.") - if (fitting) { - val labelType = schema(map(labelCol)).dataType - require(labelType == DoubleType, - s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.") - } - val fieldNames = schema.fieldNames - require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.") - require(!fieldNames.contains(map(predictionCol)), - s"Prediction column ${map(predictionCol)} already exists.") - val outputFields = schema.fields ++ Seq( - StructField(map(scoreCol), DoubleType, false), - StructField(map(predictionCol), DoubleType, false)) - StructType(outputFields) - } -} /** + * :: AlphaComponent :: + * * Logistic regression. + * Currently, this class only supports binary classification. */ -class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams { +@AlphaComponent +class LogisticRegression + extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] + with LogisticRegressionParams { setRegParam(0.1) setMaxIter(100) setThreshold(0.5) + /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) - def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) - def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap - val instances = dataset.select(map(labelCol), map(featuresCol)) - .map { case Row(label: Double, features: Vector) => - LabeledPoint(label, features) - }.persist(StorageLevel.MEMORY_AND_DISK) + override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val oldDataset = extractLabeledPoints(dataset, paramMap) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) { + oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + } + + // Train model val lr = new LogisticRegressionWithLBFGS lr.optimizer - .setRegParam(map(regParam)) - .setNumIterations(map(maxIter)) - val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights) - instances.unpersist() - // copy model params - Params.inheritValues(map, this, lrm) - lrm - } + .setRegParam(paramMap(regParam)) + .setNumIterations(paramMap(maxIter)) + val oldModel = lr.run(oldDataset) + val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept) - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = true) + if (handlePersistence) { + oldDataset.unpersist() + } + lrm } } + /** * :: AlphaComponent :: + * * Model produced by [[LogisticRegression]]. */ @AlphaComponent class LogisticRegressionModel private[ml] ( override val parent: LogisticRegression, override val fittingParamMap: ParamMap, - weights: Vector) - extends Model[LogisticRegressionModel] with LogisticRegressionParams { + val weights: Vector, + val intercept: Double) + extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] + with LogisticRegressionParams { + + setThreshold(0.5) + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) - def setPredictionCol(value: String): this.type = set(predictionCol, value) - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = false) + private val margin: Vector => Double = (features) => { + BLAS.dot(features, weights) + intercept + } + + private val score: Vector => Double = (features) => { + val m = margin(features) + 1.0 / (1.0 + math.exp(-m)) } override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This is overridden (a) to be more efficient (avoiding re-computing values when creating + // multiple output columns) and (b) to handle threshold, which the abstractions do not use. + // TODO: We should abstract away the steps defined by UDFs below so that the abstractions + // can call whichever UDFs are needed to create the output columns. + + // Check schema transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap - val scoreFunction: Vector => Double = (v) => { - val margin = BLAS.dot(v, weights) - 1.0 / (1.0 + math.exp(-margin)) + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. + // rawPrediction (-margin, margin) + // probability (1.0-score, score) + // prediction (max margin) + var tmpData = dataset + var numColsOutput = 0 + if (map(rawPredictionCol) != "") { + val features2raw: Vector => Vector = (features) => predictRaw(features) + tmpData = tmpData.withColumn(map(rawPredictionCol), + callUDF(features2raw, new VectorUDT, col(map(featuresCol)))) + numColsOutput += 1 + } + if (map(probabilityCol) != "") { + if (map(rawPredictionCol) != "") { + val raw2prob = udf { (rawPreds: Vector) => + val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) + Vectors.dense(1.0 - prob1, prob1): Vector + } + tmpData = tmpData.withColumn(map(probabilityCol), raw2prob(col(map(rawPredictionCol)))) + } else { + val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector } + tmpData = tmpData.withColumn(map(probabilityCol), features2prob(col(map(featuresCol)))) + } + numColsOutput += 1 + } + if (map(predictionCol) != "") { + val t = map(threshold) + if (map(probabilityCol) != "") { + val predict = udf { probs: Vector => + if (probs(1) > t) 1.0 else 0.0 + } + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(probabilityCol)))) + } else if (map(rawPredictionCol) != "") { + val predict = udf { rawPreds: Vector => + val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) + if (prob1 > t) 1.0 else 0.0 + } + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(rawPredictionCol)))) + } else { + val predict = udf { features: Vector => this.predict(features) } + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(featuresCol)))) + } + numColsOutput += 1 + } + if (numColsOutput == 0) { + this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" + + " since no output columns were set.") } - val t = map(threshold) - val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 } - dataset - .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol))) - .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol))) + tmpData + } + + override val numClasses: Int = 2 + + /** + * Predict label for the given feature vector. + * The behavior of this can be adjusted using [[threshold]]. + */ + override protected def predict(features: Vector): Double = { + println(s"LR.predict with threshold: ${paramMap(threshold)}") + if (score(features) > paramMap(threshold)) 1 else 0 + } + + override protected def predictProbabilities(features: Vector): Vector = { + val s = score(features) + Vectors.dense(1.0 - s, s) + } + + override protected def predictRaw(features: Vector): Vector = { + val m = margin(features) + Vectors.dense(0.0, m) + } + + override protected def copy(): LogisticRegressionModel = { + val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept) + Params.inheritValues(this.paramMap, this, m) + m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala new file mode 100644 index 0000000000000..bd8caac855981 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, StructType} + + +/** + * Params for probabilistic classification. + */ +private[classification] trait ProbabilisticClassifierParams + extends ClassifierParams with HasProbabilityCol { + + override protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) + val map = this.paramMap ++ paramMap + addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT) + } +} + + +/** + * :: AlphaComponent :: + * + * Single-label binary or multiclass classifier which can output class conditional probabilities. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam E Concrete Estimator type + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class ProbabilisticClassifier[ + FeaturesType, + E <: ProbabilisticClassifier[FeaturesType, E, M], + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams { + + /** @group setParam */ + def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E] +} + + +/** + * :: AlphaComponent :: + * + * Model produced by a [[ProbabilisticClassifier]]. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class ProbabilisticClassificationModel[ + FeaturesType, + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { + + /** @group setParam */ + def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M] + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type [[Double]] + * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]] + * - probability of each class as [[probabilityCol]] of type [[Vector]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. + + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + val (numColsOutput, outputData) = + ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map) + + // Output selected columns only. + if (map(probabilityCol) != "") { + // output probabilities + val features2probs: FeaturesType => Vector = (features) => { + tmpModel.predictProbabilities(features) + } + outputData.withColumn(map(probabilityCol), + callUDF(features2probs, new VectorUDT, col(map(featuresCol)))) + } else { + if (numColsOutput == 0) { + this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData + } + } + + /** + * :: DeveloperApi :: + * + * Predict the probability of each class given the features. + * These predictions are also called class conditional probabilities. + * + * WARNING: Not all models output well-calibrated probability estimates! These probabilities + * should be treated as confidences, not precise probabilities. + * + * This internal method is used to implement [[transform()]] and output [[probabilityCol]]. + */ + @DeveloperApi + protected def predictProbabilities(features: FeaturesType): Vector +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 1979ab9eb6516..2360f4479f1c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -18,43 +18,53 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml._ +import org.apache.spark.ml.Evaluator import org.apache.spark.ml.param._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType + /** * :: AlphaComponent :: + * * Evaluator for binary classification, which expects two input columns: score and label. */ @AlphaComponent class BinaryClassificationEvaluator extends Evaluator with Params - with HasScoreCol with HasLabelCol { + with HasRawPredictionCol with HasLabelCol { - /** param for metric name in evaluation */ + /** + * param for metric name in evaluation + * @group param + */ val metricName: Param[String] = new Param(this, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + + /** @group getParam */ def getMetricName: String = get(metricName) + + /** @group setParam */ def setMetricName(value: String): this.type = set(metricName, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) + /** @group setParam */ + def setScoreCol(value: String): this.type = set(rawPredictionCol, value) + + /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { val map = this.paramMap ++ paramMap val schema = dataset.schema - val scoreType = schema(map(scoreCol)).dataType - require(scoreType == DoubleType, - s"Score column ${map(scoreCol)} must be double type but found $scoreType") - val labelType = schema(map(labelCol)).dataType - require(labelType == DoubleType, - s"Label column ${map(labelCol)} must be double type but found $labelType") + checkInputColumn(schema, map(rawPredictionCol), new VectorUDT) + checkInputColumn(schema, map(labelCol), DoubleType) - val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol)) - .map { case Row(score: Double, label: Double) => - (score, label) + // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. + val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol)) + .map { case Row(rawPrediction: Vector, label: Double) => + (rawPrediction(1), label) } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = map(metricName) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 0956062643f23..6131ba8832691 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -31,11 +31,18 @@ import org.apache.spark.sql.types.DataType @AlphaComponent class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { - /** number of features */ + /** + * number of features + * @group param + */ val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) - def setNumFeatures(value: Int) = set(numFeatures, value) + + /** @group getParam */ def getNumFeatures: Int = get(numFeatures) + /** @group setParam */ + def setNumFeatures(value: Int) = set(numFeatures, value) + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) hashingTF.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 01a4f5eb205e5..1142aa4f8e73d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} /** @@ -39,7 +39,10 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with @AlphaComponent class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { @@ -52,7 +55,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP model } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = this.paramMap ++ paramMap val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], @@ -75,19 +78,20 @@ class StandardScalerModel private[ml] ( scaler: feature.StandardScalerModel) extends Model[StandardScalerModel] with StandardScalerParams { + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) val map = this.paramMap ++ paramMap - val scale: (Vector) => Vector = (v) => { - scaler.transform(v) - } - dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol))) + val scale = udf((v: Vector) => { scaler.transform(v) } : Vector) + dataset.withColumn(map(outputCol), scale(col(map(inputCol)))) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = this.paramMap ++ paramMap val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index e622a5cf9e6f3..0b1f90daa7d8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -29,11 +29,11 @@ import org.apache.spark.sql.types.{DataType, StringType, ArrayType} @AlphaComponent class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { - protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { + override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { _.toLowerCase.split("\\s") } - protected override def validateInputType(inputType: DataType): Unit = { + override protected def validateInputType(inputType: DataType): Unit = { require(inputType == StringType, s"Input type must be string type but got $inputType.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala new file mode 100644 index 0000000000000..dfb89cc8d4af3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl.estimator + +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + + +/** + * :: DeveloperApi :: + * + * Trait for parameters for prediction (regression and classification). + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] trait PredictorParams extends Params + with HasLabelCol with HasFeaturesCol with HasPredictionCol { + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param paramMap additional parameters + * @param fitting whether this is in fitting + * @param featuresDataType SQL DataType for FeaturesType. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val map = this.paramMap ++ paramMap + // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector + checkInputColumn(schema, map(featuresCol), featuresDataType) + if (fitting) { + // TODO: Allow other numeric types + checkInputColumn(schema, map(labelCol), DoubleType) + } + addOutputColumn(schema, map(predictionCol), DoubleType) + } +} + +/** + * :: AlphaComponent :: + * + * Abstraction for prediction problems (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @tparam Learner Specialization of this class. If you subclass this type, use this type + * parameter to specify the concrete type. + * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class Predictor[ + FeaturesType, + Learner <: Predictor[FeaturesType, Learner, M], + M <: PredictionModel[FeaturesType, M]] + extends Estimator[M] with PredictorParams { + + /** @group setParam */ + def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner] + + /** @group setParam */ + def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner] + + /** @group setParam */ + def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] + + override def fit(dataset: DataFrame, paramMap: ParamMap): M = { + // This handles a few items such as schema validation. + // Developers only need to implement train(). + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val model = train(dataset, map) + Params.inheritValues(map, this, model) // copy params to model + model + } + + /** + * :: DeveloperApi :: + * + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @param paramMap Parameter map. Unlike [[fit()]]'s paramMap, this paramMap has already + * been combined with the embedded ParamMap. + * @return Fitted model + */ + @DeveloperApi + protected def train(dataset: DataFrame, paramMap: ParamMap): M + + /** + * :: DeveloperApi :: + * + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + @DeveloperApi + protected def featuresDataType: DataType = new VectorUDT + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType) + } + + /** + * Extract [[labelCol]] and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. + */ + protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = { + val map = this.paramMap ++ paramMap + dataset.select(map(labelCol), map(featuresCol)) + .map { case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + } + } +} + +/** + * :: AlphaComponent :: + * + * Abstraction for a model for prediction tasks (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]] + extends Model[M] with PredictorParams { + + /** @group setParam */ + def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M] + + /** @group setParam */ + def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] + + /** + * :: DeveloperApi :: + * + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + @DeveloperApi + protected def featuresDataType: DataType = new VectorUDT + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType) + } + + /** + * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing + * the predictions as a new column [[predictionCol]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset with [[predictionCol]] of type [[Double]] + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. + + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + if (map(predictionCol) != "") { + val pred: FeaturesType => Double = (features) => { + tmpModel.predict(features) + } + dataset.withColumn(map(predictionCol), callUDF(pred, DoubleType, col(map(featuresCol)))) + } else { + this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + + " since no output columns were set.") + dataset + } + } + + /** + * :: DeveloperApi :: + * + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + */ + @DeveloperApi + protected def predict(features: FeaturesType): Double + + /** + * Create a copy of the model. + * The copy is shallow, except for the embedded paramMap, which gets a deep copy. + */ + protected def copy(): M +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index 51cd48c90432a..b45bd1499b72e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -20,5 +20,19 @@ package org.apache.spark /** * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. + * + * @groupname param Parameters + * @groupdesc param A list of (hyper-)parameter keys this algorithm can take. Users can set and get + * the parameter values through setters and getters, respectively. + * @groupprio param -5 + * + * @groupname setParam Parameter setters + * @groupprio setParam 5 + * + * @groupname getParam Parameter getters + * @groupprio getParam 6 + * + * @groupname Ungrouped Members + * @groupprio Ungrouped 0 */ package object ml diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 5fb4379e23c2f..17ece897a6c55 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -22,8 +22,10 @@ import scala.collection.mutable import java.lang.reflect.Modifier -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.Identifiable +import org.apache.spark.sql.types.{DataType, StructField, StructType} + /** * :: AlphaComponent :: @@ -65,37 +67,47 @@ class Param[T] ( // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... /** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None) +class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double]) extends Param[Double](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Double): ParamPair[Double] = super.w(value) } /** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None) +class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int]) extends Param[Int](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Int): ParamPair[Int] = super.w(value) } /** Specialized version of [[Param[Float]]] for Java. */ -class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None) +class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float]) extends Param[Float](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Float): ParamPair[Float] = super.w(value) } /** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None) +class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long]) extends Param[Long](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Long): ParamPair[Long] = super.w(value) } /** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None) +class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean]) extends Param[Boolean](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -158,7 +170,7 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter in the embedded param map. */ - private[ml] def set[T](param: Param[T], value: T): this.type = { + protected def set[T](param: Param[T], value: T): this.type = { require(param.parent.eq(this)) paramMap.put(param.asInstanceOf[Param[Any]], value) this @@ -174,7 +186,7 @@ trait Params extends Identifiable with Serializable { /** * Gets the value of a parameter in the embedded param map. */ - private[ml] def get[T](param: Param[T]): T = { + protected def get[T](param: Param[T]): T = { require(param.parent.eq(this)) paramMap(param) } @@ -183,9 +195,40 @@ trait Params extends Identifiable with Serializable { * Internal param map. */ protected val paramMap: ParamMap = ParamMap.empty + + /** + * Check whether the given schema contains an input column. + * @param colName Parameter name for the input column. + * @param dataType SQL DataType of the input column. + */ + protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { + val actualDataType = schema(colName).dataType + require(actualDataType.equals(dataType), + s"Input column $colName must be of type $dataType" + + s" but was actually $actualDataType. Column param description: ${getParam(colName)}") + } + + protected def addOutputColumn( + schema: StructType, + colName: String, + dataType: DataType): StructType = { + if (colName.length == 0) return schema + val fieldNames = schema.fieldNames + require(!fieldNames.contains(colName), s"Prediction column $colName already exists.") + val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false)) + StructType(outputFields) + } } -private[ml] object Params { +/** + * :: DeveloperApi :: + * + * Helper functionality for developers. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] object Params { /** * Copies parameter values from the parent estimator to the child model it produced. @@ -279,7 +322,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten def copy: ParamMap = new ParamMap(map.clone()) override def toString: String = { - map.map { case (param, value) => + map.toSeq.sortBy(_._1.name).map { case (param, value) => s"\t${param.parent.uid}-${param.name}: $value" }.mkString("{\n", ",\n", "\n}") } @@ -310,6 +353,11 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten ParamPair(param, value) } } + + /** + * Number of param pairs in this set. + */ + def size: Int = map.size } object ParamMap { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index ef141d3eb2b06..5d660d1e151a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -17,58 +17,135 @@ package org.apache.spark.ml.param +/* NOTE TO DEVELOPERS: + * If you mix these parameter traits into your algorithm, please add a setter method as well + * so that users may use a builder pattern: + * val myLearner = new MyLearner().setParam1(x).setParam2(y)... + */ + private[ml] trait HasRegParam extends Params { - /** param for regularization parameter */ + /** + * param for regularization parameter + * @group param + */ val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + + /** @group getParam */ def getRegParam: Double = get(regParam) } private[ml] trait HasMaxIter extends Params { - /** param for max number of iterations */ + /** + * param for max number of iterations + * @group param + */ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + + /** @group getParam */ def getMaxIter: Int = get(maxIter) } private[ml] trait HasFeaturesCol extends Params { - /** param for features column name */ + /** + * param for features column name + * @group param + */ val featuresCol: Param[String] = new Param(this, "featuresCol", "features column name", Some("features")) + + /** @group getParam */ def getFeaturesCol: String = get(featuresCol) } private[ml] trait HasLabelCol extends Params { - /** param for label column name */ + /** + * param for label column name + * @group param + */ val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) - def getLabelCol: String = get(labelCol) -} -private[ml] trait HasScoreCol extends Params { - /** param for score column name */ - val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score")) - def getScoreCol: String = get(scoreCol) + /** @group getParam */ + def getLabelCol: String = get(labelCol) } private[ml] trait HasPredictionCol extends Params { - /** param for prediction column name */ + /** + * param for prediction column name + * @group param + */ val predictionCol: Param[String] = new Param(this, "predictionCol", "prediction column name", Some("prediction")) + + /** @group getParam */ def getPredictionCol: String = get(predictionCol) } +private[ml] trait HasRawPredictionCol extends Params { + /** + * param for raw prediction column name + * @group param + */ + val rawPredictionCol: Param[String] = + new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name", + Some("rawPrediction")) + + /** @group getParam */ + def getRawPredictionCol: String = get(rawPredictionCol) +} + +private[ml] trait HasProbabilityCol extends Params { + /** + * param for predicted class conditional probabilities column name + * @group param + */ + val probabilityCol: Param[String] = + new Param(this, "probabilityCol", "column name for predicted class conditional probabilities", + Some("probability")) + + /** @group getParam */ + def getProbabilityCol: String = get(probabilityCol) +} + private[ml] trait HasThreshold extends Params { - /** param for threshold in (binary) prediction */ + /** + * param for threshold in (binary) prediction + * @group param + */ val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") + + /** @group getParam */ def getThreshold: Double = get(threshold) } private[ml] trait HasInputCol extends Params { - /** param for input column name */ + /** + * param for input column name + * @group param + */ val inputCol: Param[String] = new Param(this, "inputCol", "input column name") + + /** @group getParam */ def getInputCol: String = get(inputCol) } private[ml] trait HasOutputCol extends Params { - /** param for output column name */ + /** + * param for output column name + * @group param + */ val outputCol: Param[String] = new Param(this, "outputCol", "output column name") + + /** @group getParam */ def getOutputCol: String = get(outputCol) } + +private[ml] trait HasCheckpointInterval extends Params { + /** + * param for checkpoint interval + * @group param + */ + val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") + + /** @group getParam */ + def getCheckpointInterval: Int = get(checkpointInterval) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 511cb2fe4005e..fc0f529713e3a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.recommendation import java.{util => ju} +import java.io.IOException import scala.collection.mutable import scala.reflect.ClassTag @@ -26,17 +27,18 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import com.github.fommil.netlib.LAPACK.{getInstance => lapack} +import org.apache.hadoop.fs.{FileSystem, Path} import org.jblas.DoubleMatrix import org.netlib.util.intW -import org.apache.spark.{HashPartitioner, Logging, Partitioner} +import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -47,45 +49,91 @@ import org.apache.spark.util.random.XORShiftRandom * Common params for ALS. */ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam - with HasPredictionCol { + with HasPredictionCol with HasCheckpointInterval { - /** Param for rank of the matrix factorization. */ + /** + * Param for rank of the matrix factorization. + * @group param + */ val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + + /** @group getParam */ def getRank: Int = get(rank) - /** Param for number of user blocks. */ + /** + * Param for number of user blocks. + * @group param + */ val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + + /** @group getParam */ def getNumUserBlocks: Int = get(numUserBlocks) - /** Param for number of item blocks. */ + /** + * Param for number of item blocks. + * @group param + */ val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + + /** @group getParam */ def getNumItemBlocks: Int = get(numItemBlocks) - /** Param to decide whether to use implicit preference. */ + /** + * Param to decide whether to use implicit preference. + * @group param + */ val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + + /** @group getParam */ def getImplicitPrefs: Boolean = get(implicitPrefs) - /** Param for the alpha parameter in the implicit preference formulation. */ + /** + * Param for the alpha parameter in the implicit preference formulation. + * @group param + */ val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + + /** @group getParam */ def getAlpha: Double = get(alpha) - /** Param for the column name for user ids. */ + /** + * Param for the column name for user ids. + * @group param + */ val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + + /** @group getParam */ def getUserCol: String = get(userCol) - /** Param for the column name for item ids. */ + /** + * Param for the column name for item ids. + * @group param + */ val itemCol = new Param[String](this, "itemCol", "column name for item ids", Some("item")) + + /** @group getParam */ def getItemCol: String = get(itemCol) - /** Param for the column name for ratings. */ + /** + * Param for the column name for ratings. + * @group param + */ val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + + /** @group getParam */ def getRatingCol: String = get(ratingCol) + /** + * Param for whether to apply nonnegativity constraints. + * @group param + */ val nonnegative = new BooleanParam( this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false)) + + /** @group getParam */ val getNonnegative: Boolean = get(nonnegative) /** @@ -119,32 +167,31 @@ class ALSModel private[ml] ( itemFactors: RDD[(Int, Array[Float])]) extends Model[ALSModel] with ALSParams { + /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - import dataset.sqlContext.createDataFrame + import dataset.sqlContext.implicits._ val map = this.paramMap ++ paramMap - val users = userFactors.toDataFrame("id", "features") - val items = itemFactors.toDataFrame("id", "features") - val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => { + val users = userFactors.toDF("id", "features") + val items = itemFactors.toDF("id", "features") + + // Register a UDF for DataFrame, and then + // create a new column named map(predictionCol) by running the predict UDF. + val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => if (userFeatures != null && itemFeatures != null) { blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) } else { Float.NaN } } - val inputColumns = dataset.schema.fieldNames - val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol)) - val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction dataset .join(users, dataset(map(userCol)) === users("id"), "left") .join(items, dataset(map(itemCol)) === items("id"), "left") - .select(outputColumns: _*) - // TODO: Just use a dataset("*") - // .select(dataset("*"), prediction) + .select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol))) } - override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { validateAndTransformSchema(schema, paramMap) } } @@ -183,20 +230,49 @@ class ALS extends Estimator[ALSModel] with ALSParams { import org.apache.spark.ml.recommendation.ALS.Rating + /** @group setParam */ def setRank(value: Int): this.type = set(rank, value) + + /** @group setParam */ def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) + + /** @group setParam */ def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) + + /** @group setParam */ def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) + + /** @group setParam */ def setAlpha(value: Double): this.type = set(alpha, value) + + /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) + + /** @group setParam */ def setItemCol(value: String): this.type = set(itemCol, value) + + /** @group setParam */ def setRatingCol(value: String): this.type = set(ratingCol, value) + + /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ def setNonnegative(value: Boolean): this.type = set(nonnegative, value) - /** Sets both numUserBlocks and numItemBlocks to the specific value. */ + /** @group setParam */ + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** + * Sets both numUserBlocks and numItemBlocks to the specific value. + * @group setParam + */ def setNumBlocks(value: Int): this.type = { setNumUserBlocks(value) setNumItemBlocks(value) @@ -205,6 +281,7 @@ class ALS extends Estimator[ALSModel] with ALSParams { setMaxIter(20) setRegParam(1.0) + setCheckpointInterval(10) override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { val map = this.paramMap ++ paramMap @@ -216,13 +293,14 @@ class ALS extends Estimator[ALSModel] with ALSParams { val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), - alpha = map(alpha), nonnegative = map(nonnegative)) + alpha = map(alpha), nonnegative = map(nonnegative), + checkpointInterval = map(checkpointInterval)) val model = new ALSModel(this, map, map(rank), userFactors, itemFactors) Params.inheritValues(map, this, model) model } - override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { validateAndTransformSchema(schema, paramMap) } } @@ -243,7 +321,7 @@ object ALS extends Logging { /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { - /** Solves a least squares problem (possibly with other constraints). */ + /** Solves a least squares problem with regularization (possibly with other constraints). */ def solve(ne: NormalEquation, lambda: Double): Array[Float] } @@ -255,20 +333,19 @@ object ALS extends Logging { /** * Solves a least squares problem with L2 regularization: * - * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * min norm(A x - b)^2^ + lambda * norm(x)^2^ * * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) - * @param lambda regularization constant, which will be scaled by n + * @param lambda regularization constant * @return the solution x */ override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val k = ne.k // Add scaled lambda to the diagonals of AtA. - val scaledlambda = lambda * ne.n var i = 0 var j = 2 while (i < ne.triK) { - ne.ata(i) += scaledlambda + ne.ata(i) += lambda i += j j += 1 } @@ -314,7 +391,7 @@ object ALS extends Logging { override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val rank = ne.k initialize(rank) - fillAtA(ne.ata, lambda * ne.n) + fillAtA(ne.ata, lambda) val x = NNLS.solve(ata, new DoubleMatrix(rank, 1, ne.atb: _*), workspace) ne.reset() x.map(x => x.toFloat) @@ -344,7 +421,15 @@ object ALS extends Logging { } } - /** Representing a normal equation (ALS' subproblem). */ + /** + * Representing a normal equation to solve the following weighted least squares problem: + * + * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x. + * + * Its normal equation is given by + * + * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0. + */ private[recommendation] class NormalEquation(val k: Int) extends Serializable { /** Number of entries in the upper triangular part of a k-by-k matrix. */ @@ -353,8 +438,6 @@ object ALS extends Logging { val ata = new Array[Double](triK) /** A^T^ * b */ val atb = new Array[Double](k) - /** Number of observations. */ - var n = 0 private val da = new Array[Double](k) private val upper = "U" @@ -368,28 +451,13 @@ object ALS extends Logging { } /** Adds an observation. */ - def add(a: Array[Float], b: Float): this.type = { - require(a.length == k) - copyToDouble(a) - blas.dspr(upper, k, 1.0, da, 1, ata) - blas.daxpy(k, b.toDouble, da, 1, atb, 1) - n += 1 - this - } - - /** - * Adds an observation with implicit feedback. Note that this does not increment the counter. - */ - def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { + def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = { + require(c >= 0.0) require(a.length == k) - // Extension to the original paper to handle b < 0. confidence is a function of |b| instead - // so that it is never negative. - val confidence = 1.0 + alpha * math.abs(b) copyToDouble(a) - blas.dspr(upper, k, confidence - 1.0, da, 1, ata) - // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. - if (b > 0) { - blas.daxpy(k, confidence, da, 1, atb, 1) + blas.dspr(upper, k, c, da, 1, ata) + if (b != 0.0) { + blas.daxpy(k, c * b, da, 1, atb, 1) } this } @@ -399,7 +467,6 @@ object ALS extends Logging { require(other.k == k) blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1) blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1) - n += other.n this } @@ -407,7 +474,6 @@ object ALS extends Logging { def reset(): Unit = { ju.Arrays.fill(ata, 0.0) ju.Arrays.fill(atb, 0.0) - n = 0 } } @@ -426,13 +492,14 @@ object ALS extends Logging { nonnegative: Boolean = false, intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + checkpointInterval: Int = 10, seed: Long = 0L)( implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { require(intermediateRDDStorageLevel != StorageLevel.NONE, "ALS is not designed to run without persisting intermediate RDDs.") val sc = ratings.sparkContext - val userPart = new HashPartitioner(numUserBlocks) - val itemPart = new HashPartitioner(numItemBlocks) + val userPart = new ALSPartitioner(numUserBlocks) + val itemPart = new ALSPartitioner(numItemBlocks) val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) val solver = if (nonnegative) new NNLSSolver else new CholeskySolver @@ -453,6 +520,18 @@ object ALS extends Logging { val seedGen = new XORShiftRandom(seed) var userFactors = initialize(userInBlocks, rank, seedGen.nextLong()) var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) + var previousCheckpointFile: Option[String] = None + val shouldCheckpoint: Int => Boolean = (iter) => + sc.checkpointDir.isDefined && (iter % checkpointInterval == 0) + val deletePreviousCheckpointFile: () => Unit = () => + previousCheckpointFile.foreach { file => + try { + FileSystem.get(sc.hadoopConfiguration).delete(new Path(file), true) + } catch { + case e: IOException => + logWarning(s"Cannot delete checkpoint file $file:", e) + } + } if (implicitPrefs) { for (iter <- 1 to maxIter) { userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel) @@ -460,19 +539,30 @@ object ALS extends Logging { itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, userLocalIndexEncoder, implicitPrefs, alpha, solver) previousItemFactors.unpersist() - if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { - itemFactors.checkpoint() - } itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel) + // TODO: Generalize PeriodicGraphCheckpointer and use it here. + if (shouldCheckpoint(iter)) { + itemFactors.checkpoint() // itemFactors gets materialized in computeFactors. + } val previousUserFactors = userFactors userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, itemLocalIndexEncoder, implicitPrefs, alpha, solver) + if (shouldCheckpoint(iter)) { + deletePreviousCheckpointFile() + previousCheckpointFile = itemFactors.getCheckpointFile + } previousUserFactors.unpersist() } } else { for (iter <- 0 until maxIter) { itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, userLocalIndexEncoder, solver = solver) + if (shouldCheckpoint(iter)) { + itemFactors.checkpoint() + itemFactors.count() // checkpoint item factors and cut lineage + deletePreviousCheckpointFile() + previousCheckpointFile = itemFactors.getCheckpointFile + } userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, itemLocalIndexEncoder, solver = solver) } @@ -480,13 +570,23 @@ object ALS extends Logging { val userIdAndFactors = userInBlocks .mapValues(_.srcIds) .join(userFactors) - .values + .mapPartitions({ items => + items.flatMap { case (_, (ids, factors)) => + ids.view.zip(factors) + } + // Preserve the partitioning because IDs are consistent with the partitioners in userInBlocks + // and userFactors. + }, preservesPartitioning = true) .setName("userFactors") .persist(finalRDDStorageLevel) val itemIdAndFactors = itemInBlocks .mapValues(_.srcIds) .join(itemFactors) - .values + .mapPartitions({ items => + items.flatMap { case (_, (ids, factors)) => + ids.view.zip(factors) + } + }, preservesPartitioning = true) .setName("itemFactors") .persist(finalRDDStorageLevel) if (finalRDDStorageLevel != StorageLevel.NONE) { @@ -499,13 +599,7 @@ object ALS extends Logging { itemOutBlocks.unpersist() blockRatings.unpersist() } - val userOutput = userIdAndFactors.flatMap { case (ids, factors) => - ids.view.zip(factors) - } - val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) => - ids.view.zip(factors) - } - (userOutput, itemOutput) + (userIdAndFactors, itemIdAndFactors) } /** @@ -925,15 +1019,15 @@ object ALS extends Logging { "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.") val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply) (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings)) - }.groupByKey(new HashPartitioner(srcPart.numPartitions)) - .mapValues { iter => - val builder = - new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions)) - iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => - builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) - } - builder.build().compress() - }.setName(prefix + "InBlocks") + }.groupByKey(new ALSPartitioner(srcPart.numPartitions)) + .mapValues { iter => + val builder = + new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions)) + iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => + builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) + } + builder.build().compress() + }.setName(prefix + "InBlocks") .persist(storageLevel) val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) => val encoder = new LocalIndexEncoder(dstPart.numPartitions) @@ -994,7 +1088,7 @@ object ALS extends Logging { (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx)))) } } - val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.length)) + val merged = srcOut.groupByKey(new ALSPartitioner(dstInBlocks.partitions.length)) dstInBlocks.join(merged).mapValues { case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) => val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks) @@ -1010,6 +1104,7 @@ object ALS extends Logging { ls.merge(YtY.get) } var i = srcPtrs(j) + var numExplicits = 0 while (i < srcPtrs(j + 1)) { val encoded = srcEncodedIndices(i) val blockId = srcEncoder.blockId(encoded) @@ -1017,13 +1112,23 @@ object ALS extends Logging { val srcFactor = sortedSrcFactors(blockId)(localIndex) val rating = ratings(i) if (implicitPrefs) { - ls.addImplicit(srcFactor, rating, alpha) + // Extension to the original paper to handle b < 0. confidence is a function of |b| + // instead so that it is never negative. c1 is confidence - 1.0. + val c1 = alpha * math.abs(rating) + // For rating <= 0, the corresponding preference is 0. So the term below is only added + // for rating > 0. Because YtY is already added, we need to adjust the scaling here. + if (rating > 0) { + numExplicits += 1 + ls.add(srcFactor, (c1 + 1.0) / c1, c1) + } } else { ls.add(srcFactor, rating) + numExplicits += 1 } i += 1 } - dstFactors(j) = solver.solve(ls, regParam) + // Weight lambda by the number of explicit ratings based on the ALS-WR paper. + dstFactors(j) = solver.solve(ls, numExplicits * regParam) j += 1 } dstFactors @@ -1037,7 +1142,7 @@ object ALS extends Logging { private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { factorBlocks.values.aggregate(new NormalEquation(rank))( seqOp = (ne, factors) => { - factors.foreach(ne.add(_, 0.0f)) + factors.foreach(ne.add(_, 0.0)) ne }, combOp = (ne1, ne2) => ne1.merge(ne2)) @@ -1079,4 +1184,11 @@ object ALS extends Logging { encoded & localIndexMask } } + + /** + * Partitioner used by ALS. We requires that getPartition is a projection. That is, for any key k, + * we have getPartition(getPartition(k)) = getPartition(k). Since the the default HashPartitioner + * satisfies this requirement, we simply use a type alias here. + */ + private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala new file mode 100644 index 0000000000000..65f6627a0c351 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam} +import org.apache.spark.mllib.linalg.{BLAS, Vector} +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.sql.DataFrame +import org.apache.spark.storage.StorageLevel + + +/** + * Params for linear regression. + */ +private[regression] trait LinearRegressionParams extends RegressorParams + with HasRegParam with HasMaxIter + + +/** + * :: AlphaComponent :: + * + * Linear regression. + */ +@AlphaComponent +class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] + with LinearRegressionParams { + + setRegParam(0.1) + setMaxIter(100) + + /** @group setParam */ + def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val oldDataset = extractLabeledPoints(dataset, paramMap) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) { + oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + } + + // Train model + val lr = new LinearRegressionWithSGD() + lr.optimizer + .setRegParam(paramMap(regParam)) + .setNumIterations(paramMap(maxIter)) + val model = lr.run(oldDataset) + val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) + + if (handlePersistence) { + oldDataset.unpersist() + } + lrm + } +} + +/** + * :: AlphaComponent :: + * + * Model produced by [[LinearRegression]]. + */ +@AlphaComponent +class LinearRegressionModel private[ml] ( + override val parent: LinearRegression, + override val fittingParamMap: ParamMap, + val weights: Vector, + val intercept: Double) + extends RegressionModel[Vector, LinearRegressionModel] + with LinearRegressionParams { + + override protected def predict(features: Vector): Double = { + BLAS.dot(features, weights) + intercept + } + + override protected def copy(): LinearRegressionModel = { + val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept) + Params.inheritValues(this.paramMap, this, m) + m + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala new file mode 100644 index 0000000000000..d679085eeafe1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} + +/** + * :: DeveloperApi :: + * Params for regression. + * Currently empty, but may add functionality later. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] trait RegressorParams extends PredictorParams + +/** + * :: AlphaComponent :: + * + * Single-label regression + * + * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] + * @tparam Learner Concrete Estimator type + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class Regressor[ + FeaturesType, + Learner <: Regressor[FeaturesType, Learner, M], + M <: RegressionModel[FeaturesType, M]] + extends Predictor[FeaturesType, Learner, M] + with RegressorParams { + + // TODO: defaultEvaluator (follow-up PR) +} + +/** + * :: AlphaComponent :: + * + * Model produced by a [[Regressor]]. + * + * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] + * @tparam M Concrete Model type. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] with RegressorParams { + + /** + * :: DeveloperApi :: + * + * Predict real-valued label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + */ + @DeveloperApi + protected def predict(features: FeaturesType): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5d51c51346665..2eb1dac56f1e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -31,22 +31,42 @@ import org.apache.spark.sql.types.StructType * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ private[ml] trait CrossValidatorParams extends Params { - /** param for the estimator to be cross-validated */ + /** + * param for the estimator to be cross-validated + * @group param + */ val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") + + /** @group getParam */ def getEstimator: Estimator[_] = get(estimator) - /** param for estimator param maps */ + /** + * param for estimator param maps + * @group param + */ val estimatorParamMaps: Param[Array[ParamMap]] = new Param(this, "estimatorParamMaps", "param maps for the estimator") + + /** @group getParam */ def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) - /** param for the evaluator for selection */ + /** + * param for the evaluator for selection + * @group param + */ val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") + + /** @group getParam */ def getEvaluator: Evaluator = get(evaluator) - /** param for number of folds for cross validation */ + /** + * param for number of folds for cross validation + * @group param + */ val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + + /** @group getParam */ def getNumFolds: Int = get(numFolds) } @@ -59,9 +79,16 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP private val f2jBLAS = new F2jBLAS + /** @group setParam */ def setEstimator(value: Estimator[_]): this.type = set(estimator, value) + + /** @group setParam */ def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) + + /** @group setParam */ def setEvaluator(value: Evaluator): this.type = set(evaluator, value) + + /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { @@ -76,11 +103,12 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP val metrics = new Array[Double](epm.size) val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => - val trainingDataset = sqlCtx.applySchema(training, schema).cache() - val validationDataset = sqlCtx.applySchema(validation, schema).cache() + val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() + val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() // multi-model training logDebug(s"Train split $splitIndex with multiple sets of parameters.") val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + trainingDataset.unpersist() var i = 0 while (i < numModels) { val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) @@ -88,6 +116,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP metrics(i) += metric i += 1 } + validationDataset.unpersist() } f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") @@ -100,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP cvModel } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = this.paramMap ++ paramMap map(estimator).transformSchema(schema, paramMap) } @@ -121,7 +150,7 @@ class CrossValidatorModel private[ml] ( bestModel.transform(dataset, paramMap) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { bestModel.transformSchema(schema, paramMap) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala new file mode 100644 index 0000000000000..ecd3b16598438 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating} +import org.apache.spark.rdd.RDD + +/** + * A Wrapper of MatrixFactorizationModel to provide helper method for Python. + */ +private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) + extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { + + def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = + predict(SerDe.asTupleRDD(userAndProducts.rdd)) + + def getUserFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + } + + def getProductFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 980980593d194..74cbd6f5abf6d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -54,12 +54,9 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils /** - * :: DeveloperApi :: - * The Java stubs necessary for the Python mllib bindings. + * The Java stubs necessary for the Python mllib bindings. It is called by Py4J on the Python side. */ -@DeveloperApi -class PythonMLLibAPI extends Serializable { - +private[python] class PythonMLLibAPI extends Serializable { /** * Loads and serializes labeled points saved with `RDD#saveAsTextFile`. @@ -295,7 +292,7 @@ class PythonMLLibAPI extends Serializable { k: Int, convergenceTol: Double, maxIterations: Int, - seed: Long): JList[Object] = { + seed: java.lang.Long): JList[Object] = { val gmmAlg = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) @@ -338,21 +335,6 @@ class PythonMLLibAPI extends Serializable { model.predictSoft(data) } - /** - * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python - */ - private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) - extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { - - def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = - predict(SerDe.asTupleRDD(userAndProducts.rdd)) - - def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) - - def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) - - } - /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -1105,7 +1087,10 @@ private[spark] object SerDe extends Serializable { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case list: JArrayList[_] => list.asScala + case arr: Array[_] => arr + } } else { Seq(obj) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index b7a1d90d24d72..35a0db76f3a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.classification +import org.json4s.{DefaultFormats, JValue} + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector @@ -53,3 +55,15 @@ trait ClassificationModel extends Serializable { def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object ClassificationModel { + + /** + * Helper method for loading GLM classification model metadata. + * @return (numFeatures, numClasses) + */ + def getNumFeaturesClasses(metadata: JValue): (Int, Int) = { + implicit val formats = DefaultFormats + ((metadata \ "numFeatures").extract[Int], (metadata \ "numClasses").extract[Int]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index a469315a1b5c3..b787667b018e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -17,20 +17,23 @@ package org.apache.spark.mllib.classification +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} import org.apache.spark.rdd.RDD + /** * Classification model trained using Multinomial/Binary Logistic Regression. * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression. - * In Multinomial Logistic Regression, the intercepts will not be a single values, + * In Multinomial Logistic Regression, the intercepts will not be a single value, * so the intercepts will be part of the weights.) * @param numFeatures the dimension of the features. * @param numClasses the number of possible outcomes for k classes classification problem in @@ -42,8 +45,26 @@ class LogisticRegressionModel ( override val intercept: Double, val numFeatures: Int, val numClasses: Int) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Saveable { + + if (numClasses == 2) { + require(weights.size == numFeatures, + s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" + + s" numFeatures = $numFeatures, but weights.size = ${weights.size}") + } else { + val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures + val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1) + require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept, + s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" + + s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" + + s" or $weightsSizeWithIntercept (with intercept)," + + s" but was given weights of length ${weights.size}") + } + /** + * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification. + */ def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) private var threshold: Option[Double] = Some(0.5) @@ -60,6 +81,13 @@ class LogisticRegressionModel ( this } + /** + * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + */ + @Experimental + def getThreshold: Option[Double] = threshold + /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. @@ -70,14 +98,16 @@ class LogisticRegressionModel ( this } - override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double) = { require(dataMatrix.size == numFeatures) // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression. if (numClasses == 2) { require(numFeatures == weightMatrix.size) - val margin = dot(weights, dataMatrix) + intercept + val margin = dot(weightMatrix, dataMatrix) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { case Some(t) => if (score > t) 1.0 else 0.0 @@ -86,11 +116,11 @@ class LogisticRegressionModel ( } else { val dataWithBiasSize = weightMatrix.size / (numClasses - 1) - val weightsArray = weights match { + val weightsArray = weightMatrix match { case dv: DenseVector => dv.values case _ => throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weights.getClass}.") + s"weights only supports dense vector but got type ${weightMatrix.getClass}.") } val margins = (0 until numClasses - 1).map { i => @@ -126,6 +156,39 @@ class LogisticRegressionModel ( bestClass.toDouble } } + + override def save(sc: SparkContext, path: String): Unit = { + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + numFeatures, numClasses, weights, intercept, threshold) + } + + override protected def formatVersion: String = "1.0" +} + +object LogisticRegressionModel extends Loader[LogisticRegressionModel] { + + override def load(sc: SparkContext, path: String): LogisticRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + // numFeatures, numClasses, weights are checked in model initialization + val model = + new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses) + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"LogisticRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** @@ -292,6 +355,10 @@ class LogisticRegressionWithLBFGS } override protected def createModel(weights: Vector, intercept: Double) = { - new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) + if (numOfLinearPredictor == 1) { + new LogisticRegressionModel(weights, intercept) + } else { + new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a967df857bed3..b11fd4f128c56 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -18,12 +18,15 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.SparkContext._ +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} /** * Model for Naive Bayes Classifiers. @@ -36,7 +39,7 @@ import org.apache.spark.rdd.RDD class NaiveBayesModel private[mllib] ( val labels: Array[Double], val pi: Array[Double], - val theta: Array[Array[Double]]) extends ClassificationModel with Serializable { + val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable { private val brzPi = new BDV[Double](pi) private val brzTheta = new BDM[Double](theta.length, theta(0).length) @@ -65,6 +68,84 @@ class NaiveBayesModel private[mllib] ( override def predict(testData: Vector): Double = { labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) } + + override def save(sc: SparkContext, path: String): Unit = { + val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta) + NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) + } + + override protected def formatVersion: String = "1.0" +} + +object NaiveBayesModel extends Loader[NaiveBayesModel] { + + import org.apache.spark.mllib.util.Loader._ + + private object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel" + + /** Model data for model import/export */ + case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) + + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + dataRDD.saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(dataPath(path)) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta").take(1) + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + new NaiveBayesModel(labels, pi, theta) + } + } + + override def load(sc: SparkContext, path: String): NaiveBayesModel = { + val (loadedClassName, version, metadata) = loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val model = SaveLoadV1_0.load(sc, path) + assert(model.pi.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class priors vector pi had ${model.pi.size} elements") + assert(model.theta.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class conditionals array theta had ${model.theta.size} elements") + assert(model.theta.forall(_.size == numFeatures), + s"NaiveBayesModel.load expected $numFeatures features," + + s" but class conditionals array theta had elements of size:" + + s" ${model.theta.map(_.size).mkString(",")}") + model + case _ => throw new Exception( + s"NaiveBayesModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index dd514ff8a37f2..cfc7f868a02f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -17,11 +17,13 @@ package org.apache.spark.mllib.classification +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.DataValidators +import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD /** @@ -33,7 +35,8 @@ import org.apache.spark.rdd.RDD class SVMModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Saveable { private var threshold: Option[Double] = Some(0.0) @@ -49,6 +52,13 @@ class SVMModel ( this } + /** + * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + */ + @Experimental + def getThreshold: Option[Double] = threshold + /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. @@ -69,6 +79,41 @@ class SVMModel ( case None => margin } } + + override def save(sc: SparkContext, path: String): Unit = { + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + numFeatures = weights.size, numClasses = 2, weights, intercept, threshold) + } + + override protected def formatVersion: String = "1.0" +} + +object SVMModel extends Loader[SVMModel] { + + override def load(sc: SparkContext, path: String): SVMModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + val model = new SVMModel(data.weights, data.intercept) + assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" + + s" was given non-matching weights vector of size ${model.weights.size}") + assert(numClasses == 2, + s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes") + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"SVMModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala index 6a3893d0e41d2..7d33df3221fbf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -35,12 +35,13 @@ import org.apache.spark.mllib.regression.StreamingLinearAlgorithm * Use a builder pattern to construct a streaming logistic regression * analysis in an application, like: * + * {{{ * val model = new StreamingLogisticRegressionWithSGD() * .setStepSize(0.5) * .setNumIterations(10) * .setInitialWeights(Vectors.dense(...)) * .trainOn(DStream) - * + * }}} */ @Experimental class StreamingLogisticRegressionWithSGD private[mllib] ( @@ -59,9 +60,11 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( */ def this() = this(0.1, 50, 1.0, 0.0) - val algorithm = new LogisticRegressionWithSGD( + protected val algorithm = new LogisticRegressionWithSGD( stepSize, numIterations, regParam, miniBatchFraction) + protected var model: Option[LogisticRegressionModel] = None + /** Set the step size for gradient descent. Default: 0.1. */ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala new file mode 100644 index 0000000000000..8956189ff1158 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification.impl + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader +import org.apache.spark.sql.{Row, SQLContext} + +/** + * Helper class for import/export of GLM classification models. + */ +private[classification] object GLMClassificationModel { + + object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + /** Model data for import/export */ + case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) + + /** + * Helper method for saving GLM classification model metadata and data. + * @param modelClass String name for model class, to be saved with metadata + * @param numClasses Number of classes label can take, to be saved with metadata + */ + def save( + sc: SparkContext, + path: String, + modelClass: String, + numFeatures: Int, + numClasses: Int, + weights: Vector, + intercept: Double, + threshold: Option[Double]): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> numFeatures) ~ ("numClasses" -> numClasses))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val data = Data(weights, intercept, threshold) + sc.parallelize(Seq(data), 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + } + + /** + * Helper method for loading GLM classification model data. + * + * NOTE: Callers of this method should check numClasses, numFeatures on their own. + * + * @param modelClass String name for model class (used for error messages) + */ + def loadData(sc: SparkContext, path: String, modelClass: String): Data = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(datapath) + val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + val data = dataArray(0) + assert(data.size == 3, s"Unable to load $modelClass data from: $datapath") + val (weights, intercept) = data match { + case Row(weights: Vector, intercept: Double, _) => + (weights, intercept) + } + val threshold = if (data.isNullAt(2)) { + None + } else { + Some(data.getDouble(2)) + } + Data(weights, intercept, threshold) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 5c626fde4e657..568b65305649f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -19,15 +19,18 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.IndexedSeq -import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose} +import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} -import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS} +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** + * :: Experimental :: + * * This class performs expectation maximization for multivariate Gaussian * Mixture Models (GMMs). A GMM represents a composite distribution of * independent Gaussian distributions with associated "mixing" weights @@ -38,19 +41,27 @@ import org.apache.spark.util.Utils * less than convergenceTol, or until it has reached the max number of iterations. * While this process is generally guaranteed to converge, it is not guaranteed * to find a global optimum. - * + * + * Note: For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. + * * @param k The number of independent Gaussians in the mixture model * @param convergenceTol The maximum change in log-likelihood at which convergence * is considered to have occurred. * @param maxIterations The maximum number of iterations to perform */ +@Experimental class GaussianMixture private ( private var k: Int, private var convergenceTol: Double, private var maxIterations: Int, private var seed: Long) extends Serializable { - /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */ + /** + * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, + * maxIterations: 100, seed: random}. + */ def this() = this(2, 0.01, 100, Utils.random.nextLong()) // number of samples per cluster to use when initializing Gaussians @@ -123,7 +134,7 @@ class GaussianMixture private ( val sc = data.sparkContext // we will operate on the data as breeze data - val breezeData = data.map(u => u.toBreeze.toDenseVector).cache() + val breezeData = data.map(_.toBreeze).cache() // Get length of the input vectors val d = breezeData.first().length @@ -141,7 +152,7 @@ class GaussianMixture private ( (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) - }) + }) } } @@ -162,7 +173,7 @@ class GaussianMixture private ( var i = 0 while (i < k) { val mu = sums.means(i) / sums.weights(i) - BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector], + BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu), Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) weights(i) = sums.weights(i) / sumWeights gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) @@ -178,8 +189,8 @@ class GaussianMixture private ( } /** Average of dense breeze vectors */ - private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = { - val v = BreezeVector.zeros[Double](x(0).length) + private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { + val v = BDV.zeros[Double](x(0).length) x.foreach(xi => v += xi) v / x.length.toDouble } @@ -188,10 +199,10 @@ class GaussianMixture private ( * Construct matrix where diagonal entries are element-wise * variance of input vectors (computes biased variance) */ - private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = { + private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = { val mu = vectorMean(x) - val ss = BreezeVector.zeros[Double](x(0).length) - x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u) + val ss = BDV.zeros[Double](x(0).length) + x.foreach(xi => ss += (xi - mu) :^ 2.0) diag(ss / x.length.toDouble) } } @@ -200,7 +211,7 @@ class GaussianMixture private ( private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { new ExpectationSum(0.0, Array.fill(k)(0.0), - Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) + Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) } // compute cluster contributions for each input point @@ -208,19 +219,18 @@ private object ExpectationSum { def add( weights: Array[Double], dists: Array[MultivariateGaussian]) - (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = { + (sums: ExpectationSum, x: BV[Double]): ExpectationSum = { val p = weights.zip(dists).map { case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x) } val pSum = p.sum sums.logLikelihood += math.log(pSum) - val xxt = x * new Transpose(x) var i = 0 while (i < sums.k) { p(i) /= pSum sums.weights(i) += p(i) sums.means(i) += x * p(i) - BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector], + BLAS.syr(p(i), Vectors.fromBreeze(x), Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) i = i + 1 } @@ -232,7 +242,7 @@ private object ExpectationSum { private class ExpectationSum( var logLikelihood: Double, val weights: Array[Double], - val means: Array[BreezeVector[Double]], + val means: Array[BDV[Double]], val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { val k = weights.length diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 1a2178ee7f711..af6f83c74bb40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -19,12 +19,15 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD /** + * :: Experimental :: + * * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are * the respective mean and covariance for each Gaussian distribution i=1..k. @@ -35,6 +38,7 @@ import org.apache.spark.mllib.util.MLUtils * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the * covariance matrix for Gaussian i */ +@Experimental class GaussianMixtureModel( val weights: Array[Double], val gaussians: Array[MultivariateGaussian]) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index d8f82867a09d2..5e17c8da61134 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -52,6 +52,9 @@ import org.apache.spark.util.Utils * - Paper which clearly explains several algorithms, including EM: * Asuncion, Welling, Smyth, and Teh. * "On Smoothing and Inference for Topic Models." UAI, 2009. + * + * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation + * (Wikipedia)]] */ @Experimental class LDA private ( @@ -60,11 +63,10 @@ class LDA private ( private var docConcentration: Double, private var topicConcentration: Double, private var seed: Long, - private var checkpointDir: Option[String], private var checkpointInterval: Int) extends Logging { def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, - seed = Utils.random.nextLong(), checkpointDir = None, checkpointInterval = 10) + seed = Utils.random.nextLong(), checkpointInterval = 10) /** * Number of topics to infer. I.e., the number of soft cluster centers. @@ -200,50 +202,18 @@ class LDA private ( this } - /** - * Directory for storing checkpoint files during learning. - * This is not necessary, but checkpointing helps with recovery (when nodes fail). - * It also helps with eliminating temporary shuffle files on disk, which can be important when - * LDA is run for many iterations. - */ - def getCheckpointDir: Option[String] = checkpointDir - - /** - * Directory for storing checkpoint files during learning. - * This is not necessary, but checkpointing helps with recovery (when nodes fail). - * It also helps with eliminating temporary shuffle files on disk, which can be important when - * LDA is run for many iterations. - * - * NOTE: If the [[org.apache.spark.SparkContext.checkpointDir]] is already set, then the value - * given to LDA is ignored, and the existing directory is kept. - * - * (default = None) - */ - def setCheckpointDir(checkpointDir: String): this.type = { - this.checkpointDir = Some(checkpointDir) - this - } - - /** - * Clear the directory for storing checkpoint files during learning. - * If one is already set in the [[org.apache.spark.SparkContext]], then checkpointing will still - * occur; otherwise, no checkpointing will be used. - */ - def clearCheckpointDir(): this.type = { - this.checkpointDir = None - this - } - /** * Period (in iterations) between checkpoints. - * @see [[getCheckpointDir]] */ def getCheckpointInterval: Int = checkpointInterval /** - * Period (in iterations) between checkpoints. - * (default = 10) - * @see [[getCheckpointDir]] + * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery + * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be + * important when LDA is run for many iterations. If the checkpoint directory is not set in + * [[org.apache.spark.SparkContext]], this setting is ignored. + * + * @see [[org.apache.spark.SparkContext#setCheckpointDir]] */ def setCheckpointInterval(checkpointInterval: Int): this.type = { this.checkpointInterval = checkpointInterval @@ -261,7 +231,7 @@ class LDA private ( */ def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, - checkpointDir, checkpointInterval) + checkpointInterval) var iter = 0 val iterationTimes = Array.fill[Double](maxIterations)(0) while (iter < maxIterations) { @@ -337,18 +307,18 @@ private[clustering] object LDA { * Vector over topics (length k) of token counts. * The meaning of these counts can vary, and it may or may not be normalized to be a distribution. */ - type TopicCounts = BDV[Double] + private[clustering] type TopicCounts = BDV[Double] - type TokenCount = Double + private[clustering] type TokenCount = Double /** Term vertex IDs are {-1, -2, ..., -vocabSize} */ - def term2index(term: Int): Long = -(1 + term.toLong) + private[clustering] def term2index(term: Int): Long = -(1 + term.toLong) - def index2term(termIndex: Long): Int = -(1 + termIndex).toInt + private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt - def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 + private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 - def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 + private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 /** * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. @@ -360,17 +330,16 @@ private[clustering] object LDA { * @param docConcentration "alpha" * @param topicConcentration "beta" or "eta" */ - class EMOptimizer( + private[clustering] class EMOptimizer( var graph: Graph[TopicCounts, TokenCount], val k: Int, val vocabSize: Int, val docConcentration: Double, val topicConcentration: Double, - checkpointDir: Option[String], checkpointInterval: Int) { private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( - graph, checkpointDir, checkpointInterval) + graph, checkpointInterval) def next(): EMOptimizer = { val eta = topicConcentration @@ -468,7 +437,6 @@ private[clustering] object LDA { docConcentration: Double, topicConcentration: Double, randomSeed: Long, - checkpointDir: Option[String], checkpointInterval: Int): EMOptimizer = { // For each document, create an edge (Document -> Term) for each unique term in the document. val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => @@ -482,38 +450,26 @@ private[clustering] object LDA { // Create vertices. // Initially, we use random soft assignments of tokens to topics (random gamma). - val edgesWithGamma: RDD[(Edge[TokenCount], TopicCounts)] = - edges.mapPartitionsWithIndex { case (partIndex, partEdges) => - val random = new Random(partIndex + randomSeed) - partEdges.map { edge => - // Create a random gamma_{wjk} - (edge, normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)) - } - } - def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = { - val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] = - edgesWithGamma.map { case (edge, gamma: TopicCounts) => - (sendToWhere(edge), (edge.attr, gamma)) + def createVertices(): RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.flatMap { edge => + val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) + val sum = gamma * edge.attr + Seq((edge.srcId, sum), (edge.dstId, sum)) + } } - verticesTMP.aggregateByKey(BDV.zeros[Double](k))( - (sum, t) => { - brzAxpy(t._1, t._2, sum) - sum - }, - (sum0, sum1) => { - sum0 += sum1 - } - ) + verticesTMP.reduceByKey(_ + _) } - val docVertices = createVertices(_.srcId) - val termVertices = createVertices(_.dstId) + + val docTermVertices = createVertices() // Partition such that edges are grouped by document - val graph = Graph(docVertices ++ termVertices, edges) + val graph = Graph(docTermVertices, edges) .partitionBy(PartitionStrategy.EdgePartition1D) - new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointDir, - checkpointInterval) + new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 19e8aab6eabd7..0a3f21ecee0dc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -130,7 +130,7 @@ abstract class LDAModel private[clustering] { /* TODO * Compute the estimated topic distribution for each document. - * This is often called “theta” in the literature. + * This is often called 'theta' in the literature. * * @param documents RDD of documents, which are term (word) count vectors paired with IDs. * The term count vectors are "bags of words" with a fixed-size vocabulary @@ -335,7 +335,7 @@ class DistributedLDAModel private ( /** * For each document in the training set, return the distribution over topics for that document - * (i.e., "theta_doc"). + * ("theta_doc"). * * @return RDD of (document ID, topic distribution) pairs */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 9b5c155b0a805..180023922a9b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -18,6 +18,8 @@ package org.apache.spark.mllib.clustering import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.linalg.Vectors @@ -26,25 +28,33 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom /** + * :: Experimental :: + * * Model produced by [[PowerIterationClustering]]. * * @param k number of clusters - * @param assignments an RDD of (vertexID, clusterID) pairs + * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s */ +@Experimental class PowerIterationClusteringModel( val k: Int, - val assignments: RDD[(Long, Int)]) extends Serializable + val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable /** - * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and - * Cohen (see http://www.icml2010.org/papers/387.pdf). From the abstract: PIC finds a very + * :: Experimental :: + * + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + * [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. From the abstract: PIC finds a very * low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise * similarity matrix of the data. * * @param k Number of clusters. * @param maxIterations Maximum number of iterations of the PIC algorithm. * @param initMode Initialization mode. + * + * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] */ +@Experimental class PowerIterationClustering private[clustering] ( private var k: Int, private var maxIterations: Int, @@ -88,11 +98,12 @@ class PowerIterationClustering private[clustering] ( /** * Run the PIC algorithm. * - * @param similarities an RDD of (i, j, s_ij_) tuples representing the affinity matrix, which is - * the matrix A in the PIC paper. The similarity s_ij_ must be nonnegative. - * This is a symmetric matrix and hence s_ij_ = s_ji_. For any (i, j) with - * nonzero similarity, there should be either (i, j, s_ij_) or (j, i, s_ji_) - * in the input. Tuples with i = j are ignored, because we assume s_ij_ = 0.0. + * @param similarities an RDD of (i, j, s,,ij,,) tuples representing the affinity matrix, which is + * the matrix A in the PIC paper. The similarity s,,ij,, must be nonnegative. + * This is a symmetric matrix and hence s,,ij,, = s,,ji,,. For any (i, j) with + * nonzero similarity, there should be either (i, j, s,,ij,,) or + * (j, i, s,,ji,,) in the input. Tuples with i = j are ignored, because we + * assume s,,ij,, = 0.0. * * @return a [[PowerIterationClusteringModel]] that contains the clustering result */ @@ -105,25 +116,50 @@ class PowerIterationClustering private[clustering] ( pic(w0) } + /** + * A Java-friendly version of [[PowerIterationClustering.run]]. + */ + def run(similarities: JavaRDD[(java.lang.Long, java.lang.Long, java.lang.Double)]) + : PowerIterationClusteringModel = { + run(similarities.rdd.asInstanceOf[RDD[(Long, Long, Double)]]) + } + /** * Runs the PIC algorithm. * * @param w The normalized affinity matrix, which is the matrix W in the PIC paper with - * w_ij_ = a_ij_ / d_ii_ as its edge properties and the initial vector of the power + * w,,ij,, = a,,ij,, / d,,ii,, as its edge properties and the initial vector of the power * iteration as its vertex properties. */ private def pic(w: Graph[Double, Double]): PowerIterationClusteringModel = { val v = powerIter(w, maxIterations) - val assignments = kMeans(v, k) + val assignments = kMeans(v, k).mapPartitions({ iter => + iter.map { case (id, cluster) => + new Assignment(id, cluster) + } + }, preservesPartitioning = true) new PowerIterationClusteringModel(k, assignments) } } -private[clustering] object PowerIterationClustering extends Logging { +@Experimental +object PowerIterationClustering extends Logging { + + /** + * :: Experimental :: + * Cluster assignment. + * @param id node id + * @param cluster assigned cluster id + */ + @Experimental + class Assignment(val id: Long, val cluster: Int) extends Serializable + /** * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). */ - def normalize(similarities: RDD[(Long, Long, Double)]): Graph[Double, Double] = { + private[clustering] + def normalize(similarities: RDD[(Long, Long, Double)]) + : Graph[Double, Double] = { val edges = similarities.flatMap { case (i, j, s) => if (s < 0.0) { throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") @@ -154,6 +190,7 @@ private[clustering] object PowerIterationClustering extends Logging { * @return a graph with edges representing W and vertices representing a random vector * with unit 1-norm */ + private[clustering] def randomInit(g: Graph[Double, Double]): Graph[Double, Double] = { val r = g.vertices.mapPartitionsWithIndex( (part, iter) => { @@ -175,6 +212,7 @@ private[clustering] object PowerIterationClustering extends Logging { * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing the degree vector */ + private[clustering] def initDegreeVector(g: Graph[Double, Double]): Graph[Double, Double] = { val sum = g.vertices.values.sum() val v0 = g.vertices.mapValues(_ / sum) @@ -188,6 +226,7 @@ private[clustering] object PowerIterationClustering extends Logging { * @param maxIterations maximum number of iterations * @return a [[VertexRDD]] representing the pseudo-eigenvector */ + private[clustering] def powerIter( g: Graph[Double, Double], maxIterations: Int): VertexRDD[Double] = { @@ -227,6 +266,7 @@ private[clustering] object PowerIterationClustering extends Logging { * @param k number of clusters * @return a [[VertexRDD]] representing the clustering assignments */ + private[clustering] def kMeans(v: VertexRDD[Double], k: Int): VertexRDD[Int] = { val points = v.mapValues(x => Vectors.dense(x)).cache() val model = new KMeans() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 7752c1988fdd1..f483fd1c7d2cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.DStream @@ -29,7 +29,8 @@ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** - * :: DeveloperApi :: + * :: Experimental :: + * * StreamingKMeansModel extends MLlib's KMeansModel for streaming * algorithms, so it can keep track of a continuously updated weight * associated with each cluster, and also update the model by @@ -39,8 +40,10 @@ import org.apache.spark.util.random.XORShiftRandom * generalized to incorporate forgetfullness (i.e. decay). * The update rule (for each cluster) is: * + * {{{ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] * n_t+t = n_t * a + m_t + * }}} * * Where c_t is the previously estimated centroid for that cluster, * n_t is the number of points assigned to it thus far, x_t is the centroid @@ -61,7 +64,7 @@ import org.apache.spark.util.random.XORShiftRandom * as batches or points. * */ -@DeveloperApi +@Experimental class StreamingKMeansModel( override val clusterCenters: Array[Vector], val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { @@ -140,7 +143,8 @@ class StreamingKMeansModel( } /** - * :: DeveloperApi :: + * :: Experimental :: + * * StreamingKMeans provides methods for configuring a * streaming k-means analysis, training the model on streaming, * and using the model to make predictions on streaming data. @@ -149,13 +153,15 @@ class StreamingKMeansModel( * Use a builder pattern to construct a streaming k-means analysis * in an application, like: * + * {{{ * val model = new StreamingKMeans() * .setDecayFactor(0.5) * .setK(3) * .setRandomCenters(5, 100.0) * .trainOn(DStream) + * }}} */ -@DeveloperApi +@Experimental class StreamingKMeans( var k: Int, var decayFactor: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index a3e40200bc063..59a79e5c6a4ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -272,7 +272,7 @@ class Word2Vec extends Serializable with Logging { def hasNext: Boolean = iter.hasNext def next(): Array[Int] = { - var sentence = new ArrayBuffer[Int] + val sentence = ArrayBuilder.make[Int] var sentenceLength = 0 while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { val word = bcVocabHash.value.get(iter.next()) @@ -283,7 +283,7 @@ class Word2Vec extends Serializable with Logging { case None => } } - sentence.toArray + sentence.result() } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 9591c7966e06a..efa8459d3cdba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -18,38 +18,60 @@ package org.apache.spark.mllib.fpm import java.{util => ju} +import java.lang.{Iterable => JavaIterable} import scala.collection.mutable +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag -import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner} +import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable +/** + * :: Experimental :: + * + * Model trained by [[FPGrowth]], which holds frequent itemsets. + * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] + * @tparam Item item type + */ +@Experimental +class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable /** - * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data. - * Parallel FPGrowth (PFP) partitions computation in such a way that each machine executes an - * independent group of mining tasks. More detail of this algorithm can be found at - * [[http://dx.doi.org/10.1145/1454008.1454027, PFP]], and the original FP-growth paper can be - * found at [[http://dx.doi.org/10.1145/335191.335372, FP-growth]] + * :: Experimental :: + * + * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in + * [[http://dx.doi.org/10.1145/1454008.1454027 Li et al., PFP: Parallel FP-Growth for Query + * Recommendation]]. PFP distributes computation in such a way that each worker executes an + * independent group of mining tasks. The FP-Growth algorithm is described in + * [[http://dx.doi.org/10.1145/335191.335372 Han et al., Mining frequent patterns without candidate + * generation]]. * * @param minSupport the minimal support level of the frequent pattern, any pattern appears * more than (minSupport * size-of-the-dataset) times will be output * @param numPartitions number of partitions used by parallel FP-growth + * + * @see [[http://en.wikipedia.org/wiki/Association_rule_learning Association rule learning + * (Wikipedia)]] */ +@Experimental class FPGrowth private ( private var minSupport: Double, private var numPartitions: Int) extends Logging with Serializable { /** - * Constructs a FPGrowth instance with default parameters: - * {minSupport: 0.3, numPartitions: auto} + * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same + * as the input data}. */ def this() = this(0.3, -1) /** - * Sets the minimal support level (default: 0.3). + * Sets the minimal support level (default: `0.3`). */ def setMinSupport(minSupport: Double): this.type = { this.minSupport = minSupport @@ -69,7 +91,7 @@ class FPGrowth private ( * @param data input data set, each element contains a transaction * @return an [[FPGrowthModel]] */ - def run(data: RDD[Array[String]]): FPGrowthModel = { + def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } @@ -82,19 +104,24 @@ class FPGrowth private ( new FPGrowthModel(freqItemsets) } + def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = { + implicit val tag = fakeClassTag[Item] + run(data.rdd.map(_.asScala.toArray)) + } + /** * Generates frequent items by filtering the input data using minimal support level. * @param minCount minimum count for frequent itemsets * @param partitioner partitioner used to distribute items * @return array of frequent pattern ordered by their frequencies */ - private def genFreqItems( - data: RDD[Array[String]], + private def genFreqItems[Item: ClassTag]( + data: RDD[Array[Item]], minCount: Long, - partitioner: Partitioner): Array[String] = { + partitioner: Partitioner): Array[Item] = { data.flatMap { t => val uniq = t.toSet - if (t.length != uniq.size) { + if (t.size != uniq.size) { throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.") } t @@ -114,11 +141,11 @@ class FPGrowth private ( * @param partitioner partitioner used to distribute transactions * @return an RDD of (frequent itemset, count) */ - private def genFreqItemsets( - data: RDD[Array[String]], + private def genFreqItemsets[Item: ClassTag]( + data: RDD[Array[Item]], minCount: Long, - freqItems: Array[String], - partitioner: Partitioner): RDD[(Array[String], Long)] = { + freqItems: Array[Item], + partitioner: Partitioner): RDD[FreqItemset[Item]] = { val itemToRank = freqItems.zipWithIndex.toMap data.flatMap { transaction => genCondTransactions(transaction, itemToRank, partitioner) @@ -128,7 +155,7 @@ class FPGrowth private ( .flatMap { case (part, tree) => tree.extract(minCount, x => partitioner.getPartition(x) == part) }.map { case (ranks, count) => - (ranks.map(i => freqItems(i)).toArray, count) + new FreqItemset(ranks.map(i => freqItems(i)).toArray, count) } } @@ -139,9 +166,9 @@ class FPGrowth private ( * @param partitioner partitioner used to distribute transactions * @return a map of (target partition, conditional transaction) */ - private def genCondTransactions( - transaction: Array[String], - itemToRank: Map[String, Int], + private def genCondTransactions[Item: ClassTag]( + transaction: Array[Item], + itemToRank: Map[Item, Int], partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { val output = mutable.Map.empty[Int, Array[Int]] // Filter the basket by frequent items pattern and sort their ranks. @@ -160,3 +187,26 @@ class FPGrowth private ( output } } + +/** + * :: Experimental :: + */ +@Experimental +object FPGrowth { + + /** + * Frequent itemset. + * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead. + * @param freq frequency + * @tparam Item item type + */ + class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { + + /** + * Returns items in a Java List. + */ + def javaItems: java.util.List[Item] = { + items.toList.asJava + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 76672fe51e834..6e5dd119dd653 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -74,7 +74,6 @@ import org.apache.spark.storage.StorageLevel * }}} * * @param currentGraph Initial graph - * @param checkpointDir The directory for storing checkpoint files * @param checkpointInterval Graphs will be checkpointed at this interval * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type @@ -83,7 +82,6 @@ import org.apache.spark.storage.StorageLevel */ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( var currentGraph: Graph[VD, ED], - val checkpointDir: Option[String], val checkpointInterval: Int) extends Logging { /** FIFO queue of past checkpointed RDDs */ @@ -101,12 +99,6 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( */ private val sc = currentGraph.vertices.sparkContext - // If a checkpoint directory is given, and there's no prior checkpoint directory, - // then set the checkpoint directory with the given one. - if (checkpointDir.nonEmpty && sc.getCheckpointDir.isEmpty) { - sc.setCheckpointDir(checkpointDir.get) - } - updateGraph(currentGraph) /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 079f7ca564a92..87052e1ba8539 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -235,12 +235,24 @@ private[spark] object BLAS extends Serializable with Logging { * @param x the vector x that contains the n elements. * @param A the symmetric matrix A. Size of n x n. */ - def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + def syr(alpha: Double, x: Vector, A: DenseMatrix) { val mA = A.numRows val nA = A.numCols - require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA") + require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA") require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}") + x match { + case dv: DenseVector => syr(alpha, dv, A) + case sv: SparseVector => syr(alpha, sv, A) + case _ => + throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.") + } + } + + private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + val nA = A.numRows + val mA = A.numCols + nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA) // Fill lower triangular part of A @@ -255,6 +267,26 @@ private[spark] object BLAS extends Serializable with Logging { } } + private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { + val mA = A.numCols + val xIndices = x.indices + val xValues = x.values + val nnz = xValues.length + val Avalues = A.values + + var i = 0 + while (i < nnz) { + val multiplier = alpha * xValues(i) + val offset = xIndices(i) * mA + var j = 0 + while (j < nnz) { + Avalues(xIndices(j) + offset) += multiplier * xValues(j) + j += 1 + } + i += 1 + } + } + /** * C := alpha * A * B + beta * C * @param alpha a scalar to scale the multiplication A * B. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 3515461b52493..623bdc5594f3a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -114,7 +114,7 @@ private[mllib] object EigenValueDecomposition { info.`val` match { case 1 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + " Maximum number of iterations taken. (Refer ARPACK user guide for details)") - case 2 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + + case 3 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + " No shifts could be applied. Try to increase NCV. " + "(Refer ARPACK user guide for details)") case _ => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index ad7e86827b368..0e4a4d0085895 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -50,7 +50,7 @@ sealed trait Matrix extends Serializable { private[mllib] def toBreeze: BM[Double] /** Gets the (i, j)-th element. */ - private[mllib] def apply(i: Int, j: Int): Double + def apply(i: Int, j: Int): Double /** Return the index for the (i, j)-th element in the backing array. */ private[mllib] def index(i: Int, j: Int): Int @@ -115,7 +115,7 @@ sealed trait Matrix extends Serializable { * * @param numRows number of rows * @param numCols number of columns - * @param values matrix entries in column major + * @param values matrix entries in column major if not transposed or in row major otherwise * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in * row major. */ @@ -163,7 +163,7 @@ class DenseMatrix( private[mllib] def apply(i: Int): Double = values(i) - private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j)) + override def apply(i: Int, j: Int): Double = values(index(i, j)) private[mllib] def index(i: Int, j: Int): Int = { if (!isTransposed) i + numRows * j else j + numCols * i @@ -187,7 +187,7 @@ class DenseMatrix( this } - override def transpose: Matrix = new DenseMatrix(numCols, numRows, values, !isTransposed) + override def transpose: DenseMatrix = new DenseMatrix(numCols, numRows, values, !isTransposed) private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { if (!isTransposed) { @@ -217,9 +217,11 @@ class DenseMatrix( } } - /** Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed - * set to false. */ - def toSparse(): SparseMatrix = { + /** + * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed + * set to false. + */ + def toSparse: SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt @@ -254,8 +256,11 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros */ - def zeros(numRows: Int, numCols: Int): DenseMatrix = + def zeros(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + } /** * Generate a `DenseMatrix` consisting of ones. @@ -263,8 +268,11 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones */ - def ones(numRows: Int, numCols: Int): DenseMatrix = + def ones(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + } /** * Generate an Identity Matrix in `DenseMatrix` format. @@ -282,24 +290,28 @@ object DenseMatrix { } /** - * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) } /** - * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) } @@ -336,10 +348,10 @@ object DenseMatrix { * * @param numRows number of rows * @param numCols number of columns - * @param colPtrs the index corresponding to the start of a new column - * @param rowIndices the row index of the entry. They must be in strictly increasing order for each - * column - * @param values non-zero matrix entries in column major + * @param colPtrs the index corresponding to the start of a new column (if not transposed) + * @param rowIndices the row index of the entry (if not transposed). They must be in strictly + * increasing order for each column + * @param values nonzero matrix entries in column major (if not transposed) * @param isTransposed whether the matrix is transposed. If true, the matrix can be considered * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, * and `rowIndices` behave as colIndices, and `values` are stored in row major. @@ -396,7 +408,7 @@ class SparseMatrix( } } - private[mllib] def apply(i: Int, j: Int): Double = { + override def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) } @@ -434,7 +446,7 @@ class SparseMatrix( this } - override def transpose: Matrix = + override def transpose: SparseMatrix = new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { @@ -464,9 +476,11 @@ class SparseMatrix( } } - /** Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed - * set to false. */ - def toDense(): DenseMatrix = { + /** + * Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed + * set to false. + */ + def toDense: DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } } @@ -593,7 +607,7 @@ object SparseMatrix { } /** - * Generate a `SparseMatrix` consisting of i.i.d. uniform random numbers. The number of non-zero + * Generate a `SparseMatrix` consisting of `i.i.d`. uniform random numbers. The number of non-zero * elements equal the ceiling of `numRows` x `numCols` x `density` * * @param numRows number of rows of the matrix @@ -608,7 +622,7 @@ object SparseMatrix { } /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d`. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix @@ -626,7 +640,7 @@ object SparseMatrix { * @return Square `SparseMatrix` with size `values.length` x `values.length` and non-zero * `values` on the diagonal */ - def diag(vector: Vector): SparseMatrix = { + def spdiag(vector: Vector): SparseMatrix = { val n = vector.size vector match { case sVec: SparseVector => @@ -692,7 +706,7 @@ object Matrices { } /** - * Generate a `DenseMatrix` consisting of zeros. + * Generate a `Matrix` consisting of zeros. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of zeros @@ -722,7 +736,7 @@ object Matrices { def speye(n: Int): Matrix = SparseMatrix.speye(n) /** - * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator @@ -732,7 +746,7 @@ object Matrices { DenseMatrix.rand(numRows, numCols, rng) /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix @@ -743,7 +757,7 @@ object Matrices { SparseMatrix.sprand(numRows, numCols, density, rng) /** - * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator @@ -753,7 +767,7 @@ object Matrices { DenseMatrix.randn(numRows, numCols, rng) /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix @@ -764,8 +778,8 @@ object Matrices { SparseMatrix.sprandn(numRows, numCols, density, rng) /** - * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. - * @param vector a `Vector` tat will form the values on the diagonal of the matrix + * Generate a diagonal matrix in `Matrix` format from the supplied values. + * @param vector a `Vector` that will form the values on the diagonal of the matrix * @return Square `Matrix` with size `values.length` x `values.length` and `values` * on the diagonal */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 8f75e6f46e05d..e9d25dcb7e778 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -26,6 +26,7 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.util.NumericParser import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow @@ -110,9 +111,14 @@ sealed trait Vector extends Serializable { } /** + * :: DeveloperApi :: + * * User-defined type for [[Vector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.DataFrame]]. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ +@DeveloperApi private[spark] class VectorUDT extends UserDefinedType[Vector] { override def sqlType: StructType = { @@ -169,6 +175,15 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT" override def userClass: Class[Vector] = classOf[Vector] + + override def equals(o: Any): Boolean = { + o match { + case v: VectorUDT => true + case _ => false + } + } + + private[spark] override def asNullable: VectorUDT = this } /** @@ -234,7 +249,7 @@ object Vectors { } /** - * Creates a dense vector of all zeros. + * Creates a vector of all zeros. * * @param size vector size * @return a zero vector @@ -244,8 +259,7 @@ object Vectors { } /** - * Parses a string resulted from `Vector#toString` into - * an [[org.apache.spark.mllib.linalg.Vector]]. + * Parses a string resulted from [[Vector.toString]] into a [[Vector]]. */ def parse(s: String): Vector = { parseNumeric(NumericParser.parse(s)) @@ -483,6 +497,7 @@ class DenseVector(val values: Array[Double]) extends Vector { } object DenseVector { + /** Extracts the value array from a dense vector. */ def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 3871152d065a7..1d253963130f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -21,7 +21,8 @@ import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.{SparkException, Logging, Partitioner} +import org.apache.spark.{Logging, Partitioner, SparkException} +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -104,6 +105,8 @@ private[mllib] object GridPartitioner { } /** + * :: Experimental :: + * * Represents a distributed matrix in blocks of local matrices. * * @param blocks The RDD of sub-matrix blocks ((blockRowIndex, blockColIndex), sub-matrix) that @@ -118,6 +121,7 @@ private[mllib] object GridPartitioner { * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to * zero, the number of columns will be calculated when `numCols` is invoked. */ +@Experimental class BlockMatrix( val blocks: RDD[((Int, Int), Matrix)], val rowsPerBlock: Int, @@ -177,6 +181,10 @@ class BlockMatrix( assert(cols <= nCols, s"The number of columns $cols is more than claimed $nCols.") } + /** + * Validates the block matrix info against the matrix data (`blocks`) and throws an exception if + * any error is found. + */ def validate(): Unit = { logDebug("Validating BlockMatrix...") // check if the matrix is larger than the claimed dimensions @@ -351,7 +359,7 @@ class BlockMatrix( if (a.nonEmpty && b.nonEmpty) { val C = b.head match { case dense: DenseMatrix => a.head.multiply(dense) - case sparse: SparseMatrix => a.head.multiply(sparse.toDense()) + case sparse: SparseMatrix => a.head.multiply(sparse.toDense) case _ => throw new SparkException(s"Unrecognized matrix type ${b.head.getClass}.") } Iterator(((blockRowIndex, blockColIndex), C.toBreeze)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 0acdab797e8f3..8bfa0d2b64995 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -63,10 +63,12 @@ abstract class Gradient extends Serializable { * http://statweb.stanford.edu/~tibs/ElemStatLearn/ , Eq. (4.17) on page 119 gives the formula of * multinomial logistic regression model. A simple calculation shows that * + * {{{ * P(y=0|x, w) = 1 / (1 + \sum_i^{K-1} \exp(x w_i)) * P(y=1|x, w) = exp(x w_1) / (1 + \sum_i^{K-1} \exp(x w_i)) * ... * P(y=K-1|x, w) = exp(x w_{K-1}) / (1 + \sum_i^{K-1} \exp(x w_i)) + * }}} * * for K classes multiclass classification problem. * @@ -75,9 +77,11 @@ abstract class Gradient extends Serializable { * will be (K-1) * N. * * As a result, the loss of objective function for a single instance of data can be written as + * {{{ * l(w, x) = -log P(y|x, w) = -\alpha(y) log P(y=0|x, w) - (1-\alpha(y)) log P(y|x, w) * = log(1 + \sum_i^{K-1}\exp(x w_i)) - (1-\alpha(y)) x w_{y-1} * = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1} + * }}} * * where \alpha(i) = 1 if i != 0, and * \alpha(i) = 0 if i == 0, @@ -86,14 +90,16 @@ abstract class Gradient extends Serializable { * For optimization, we have to calculate the first derivative of the loss function, and * a simple calculation shows that * + * {{{ * \frac{\partial l(w, x)}{\partial w_{ij}} * = (\exp(x w_i) / (1 + \sum_k^{K-1} \exp(x w_k)) - (1-\alpha(y)\delta_{y, i+1})) * x_j * = multiplier_i * x_j + * }}} * * where \delta_{i, j} = 1 if i == j, * \delta_{i, j} = 0 if i != j, and - * multiplier - * = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_i)) - (1-\alpha(y)\delta_{y, i+1}) + * multiplier = + * \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_i)) - (1-\alpha(y)\delta_{y, i+1}) * * If any of margins is larger than 709.78, the numerical computation of multiplier and loss * function will be suffered from arithmetic overflow. This issue occurs when there are outliers @@ -103,10 +109,12 @@ abstract class Gradient extends Serializable { * Fortunately, when max(margins) = maxMargin > 0, the loss function and the multiplier can be * easily rewritten into the following equivalent numerically stable formula. * + * {{{ * l(w, x) = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1} * = log(\exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin)) + maxMargin * - (1-\alpha(y)) margins_{y-1} * = log(1 + sum) + maxMargin - (1-\alpha(y)) margins_{y-1} + * }}} * * where sum = \exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin) - 1. * @@ -115,8 +123,10 @@ abstract class Gradient extends Serializable { * * For multiplier, similar trick can be applied as the following, * + * {{{ * multiplier = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_i)) - (1-\alpha(y)\delta_{y, i+1}) * = \exp(margins_i - maxMargin) / (1 + sum) - (1-\alpha(y)\delta_{y, i+1}) + * }}} * * where each term in \exp is also smaller than zero, so overflow is not a concern. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index d5e4f4ccbff10..ef6eccd90711a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -60,6 +60,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) /** * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. * Smaller value will lead to higher accuracy with the cost of more iterations. + * This value must be nonnegative. Lower convergence values are less tolerant + * and therefore generally cause more iterations to be run. */ def setConvergenceTol(tolerance: Double): this.type = { this.convergenceTol = tolerance @@ -142,7 +144,9 @@ object LBFGS extends Logging { * one single data example) * @param updater - Updater function to actually perform a gradient step in a given direction. * @param numCorrections - The number of corrections used in the L-BFGS update. - * @param convergenceTol - The convergence tolerance of iterations for L-BFGS + * @param convergenceTol - The convergence tolerance of iterations for L-BFGS which is must be + * nonnegative. Lower values are less tolerant and therefore generally + * cause more iterations to be run. * @param maxNumIterations - Maximal number of iterations that L-BFGS can be run. * @param regParam - Regularization parameter * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 955c593a085d5..8341bb86afd71 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -29,13 +29,13 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. + * Generator methods for creating RDDs comprised of `i.i.d.` samples from some distribution. */ @Experimental object RandomRDDs { /** - * Generates an RDD comprised of i.i.d. samples from the uniform distribution `U(0.0, 1.0)`. + * Generates an RDD comprised of `i.i.d.` samples from the uniform distribution `U(0.0, 1.0)`. * * To transform the distribution in the generated RDD from `U(0.0, 1.0)` to `U(a, b)`, use * `RandomRDDs.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. @@ -44,7 +44,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ `U(0.0, 1.0)`. + * @return RDD[Double] comprised of `i.i.d.` samples ~ `U(0.0, 1.0)`. */ def uniformRDD( sc: SparkContext, @@ -81,7 +81,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * Generates an RDD comprised of `i.i.d.` samples from the standard normal distribution. * * To transform the distribution in the generated RDD from standard normal to some other normal * `N(mean, sigma^2^)`, use `RandomRDDs.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. @@ -90,7 +90,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + * @return RDD[Double] comprised of `i.i.d.` samples ~ N(0.0, 1.0). */ def normalRDD( sc: SparkContext, @@ -127,14 +127,15 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of `i.i.d.` samples from the Poisson distribution with the input + * mean. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ def poissonRDD( sc: SparkContext, @@ -177,7 +178,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the exponential distribution with + * Generates an RDD comprised of `i.i.d.` samples from the exponential distribution with * the input mean. * * @param sc SparkContext used to create the RDD. @@ -185,7 +186,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ def exponentialRDD( sc: SparkContext, @@ -228,7 +229,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the gamma distribution with the input + * Generates an RDD comprised of `i.i.d.` samples from the gamma distribution with the input * shape and scale. * * @param sc SparkContext used to create the RDD. @@ -237,7 +238,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ def gammaRDD( sc: SparkContext, @@ -287,7 +288,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the log normal distribution with the input + * Generates an RDD comprised of `i.i.d.` samples from the log normal distribution with the input * mean and standard deviation * * @param sc SparkContext used to create the RDD. @@ -296,7 +297,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ def logNormalRDD( sc: SparkContext, @@ -348,14 +349,14 @@ object RandomRDDs { /** * :: DeveloperApi :: - * Generates an RDD comprised of i.i.d. samples produced by the input RandomDataGenerator. + * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. * @param generator RandomDataGenerator used to populate the RDD. * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples produced by generator. + * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi def randomRDD[T: ClassTag]( @@ -370,7 +371,7 @@ object RandomRDDs { // TODO Generate RDD[Vector] from multivariate distributions. /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * uniform distribution on `U(0.0, 1.0)`. * * @param sc SparkContext used to create the RDD. @@ -424,7 +425,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. @@ -432,7 +433,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ `N(0.0, 1.0)`. */ def normalVectorRDD( sc: SparkContext, @@ -478,7 +479,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from a + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from a * log normal distribution. * * @param sc SparkContext used to create the RDD. @@ -488,7 +489,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples. + * @return RDD[Vector] with vectors containing `i.i.d.` samples. */ def logNormalVectorRDD( sc: SparkContext, @@ -544,7 +545,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -553,7 +554,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Pois(mean). */ def poissonVectorRDD( sc: SparkContext, @@ -603,7 +604,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * exponential distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -612,7 +613,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Exp(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ def exponentialVectorRDD( sc: SparkContext, @@ -665,7 +666,7 @@ object RandomRDDs { /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * gamma distribution with the input shape and scale. * * @param sc SparkContext used to create the RDD. @@ -675,7 +676,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Exp(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ def gammaVectorRDD( sc: SparkContext, @@ -731,7 +732,7 @@ object RandomRDDs { /** * :: DeveloperApi :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples produced by the * input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. @@ -740,7 +741,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. + * @return RDD[Vector] with vectors containing `i.i.d.` samples produced by generator. */ @DeveloperApi def randomVectorRDD(sc: SparkContext, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index f4f51f2ac5210..dddefe1944e9d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -18,17 +18,15 @@ package org.apache.spark.mllib.recommendation import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.recommendation.{ALS => NewALS} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** - * :: Experimental :: * A more compact class to represent a rating than Tuple3[Int, Int, Double]. */ -@Experimental case class Rating(user: Int, product: Int, rating: Double) /** @@ -84,6 +82,9 @@ class ALS private ( private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK private var finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK + /** checkpoint interval */ + private var checkpointInterval: Int = 10 + /** * Set the number of blocks for both user blocks and product blocks to parallelize the computation * into; pass -1 for an auto-configured number of blocks. Default: -1. @@ -135,10 +136,8 @@ class ALS private ( } /** - * :: Experimental :: * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. */ - @Experimental def setAlpha(alpha: Double): this.type = { this.alpha = alpha this @@ -186,6 +185,19 @@ class ALS private ( this } + /** + * Set period (in iterations) between checkpoints (default = 10). Checkpointing helps with + * recovery (when nodes fail) and StackOverflow exceptions caused by long lineage. It also helps + * with eliminating temporary shuffle files on disk, which can be important when there are many + * ALS iterations. If the checkpoint directory is not set in [[org.apache.spark.SparkContext]], + * this setting is ignored. + */ + @DeveloperApi + def setCheckpointInterval(checkpointInterval: Int): this.type = { + this.checkpointInterval = checkpointInterval + this + } + /** * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. @@ -216,6 +228,7 @@ class ALS private ( nonnegative = nonnegative, intermediateRDDStorageLevel = intermediateRDDStorageLevel, finalRDDStorageLevel = StorageLevel.NONE, + checkpointInterval = checkpointInterval, seed = seed) val userFactors = floatUserFactors diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ed2f8b41bcae5..c399496568bfb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -17,13 +17,20 @@ package org.apache.spark.mllib.recommendation +import java.io.IOException import java.lang.{Integer => JavaInteger} +import org.apache.hadoop.fs.Path import org.jblas.DoubleMatrix +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.storage.StorageLevel /** @@ -41,7 +48,8 @@ import org.apache.spark.storage.StorageLevel class MatrixFactorizationModel( val rank: Int, val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging { + val productFeatures: RDD[(Int, Array[Double])]) + extends Saveable with Serializable with Logging { require(rank > 0) validateFeatures("User", userFeatures) @@ -125,6 +133,12 @@ class MatrixFactorizationModel( recommend(productFeatures.lookup(product).head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) + protected override val formatVersion: String = "1.0" + + override def save(sc: SparkContext, path: String): Unit = { + MatrixFactorizationModel.SaveLoadV1_0.save(this, path) + } + private def recommend( recommendToFeatures: Array[Double], recommendableFeatures: RDD[(Int, Array[Double])], @@ -136,3 +150,71 @@ class MatrixFactorizationModel( scored.top(num)(Ordering.by(_._2)) } } + +object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { + + import org.apache.spark.mllib.util.Loader._ + + override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + val (loadedClassName, formatVersion, _) = loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, formatVersion) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case _ => + throw new IOException("MatrixFactorizationModel.load did not recognize model with" + + s"(class: $loadedClassName, version: $formatVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private[recommendation] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[recommendation] + val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel" + + /** + * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and + * product features are saved under `data/products`. + */ + def save(model: MatrixFactorizationModel, path: String): Unit = { + val sc = model.userFeatures.sparkContext + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + model.userFeatures.toDF("id", "features").saveAsParquetFile(userPath(path)) + model.productFeatures.toDF("id", "features").saveAsParquetFile(productPath(path)) + } + + def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + val (className, formatVersion, metadata) = loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rank = (metadata \ "rank").extract[Int] + val userFeatures = sqlContext.parquetFile(userPath(path)) + .map { case Row(id: Int, features: Seq[Double]) => + (id, features.toArray) + } + val productFeatures = sqlContext.parquetFile(productPath(path)) + .map { case Row(id: Int, features: Seq[Double]) => + (id, features.toArray) + } + new MatrixFactorizationModel(rank, userFeatures, productFeatures) + } + + private def userPath(path: String): String = { + new Path(dataPath(path), "user").toUri.toString + } + + private def productPath(path: String): String = { + new Path(dataPath(path), "product").toUri.toString + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 17de215b97f9d..9a2751aef908e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -126,7 +126,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * The dimension of training features. */ - protected var numFeatures: Int = 0 + protected var numFeatures: Int = -1 /** * Set if the algorithm should use feature scaling to improve the convergence during optimization. @@ -163,7 +163,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * RDD of LabeledPoint entries. */ def run(input: RDD[LabeledPoint]): M = { - numFeatures = input.first().features.size + if (numFeatures < 0) { + numFeatures = input.map(_.features.size).first() + } /** * When `numOfLinearPredictor > 1`, the intercepts are encapsulated into weights, @@ -193,7 +195,10 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * of LabeledPoint entries starting from the initial weights provided. */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { - numFeatures = input.first().features.size + + if (numFeatures < 0) { + numFeatures = input.map(_.features.size).first() + } if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" @@ -205,7 +210,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] throw new SparkException("Input validation failed.") } - /** + /* * Scaling columns to unit variance as a heuristic to reduce the condition number: * * During the optimization process, the convergence (rate) depends on the condition number of @@ -225,26 +230,27 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * Currently, it's only enabled in LogisticRegressionWithLBFGS */ val scaler = if (useFeatureScaling) { - (new StandardScaler(withStd = true, withMean = false)).fit(input.map(x => x.features)) + new StandardScaler(withStd = true, withMean = false).fit(input.map(_.features)) } else { null } // Prepend an extra variable consisting of all 1.0's for the intercept. - val data = if (addIntercept) { - if (useFeatureScaling) { - input.map(labeledPoint => - (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features)))) - } else { - input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) - } - } else { - if (useFeatureScaling) { - input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features))) + // TODO: Apply feature scaling to the weight vector instead of input data. + val data = + if (addIntercept) { + if (useFeatureScaling) { + input.map(lp => (lp.label, appendBias(scaler.transform(lp.features)))).cache() + } else { + input.map(lp => (lp.label, appendBias(lp.features))).cache() + } } else { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + if (useFeatureScaling) { + input.map(lp => (lp.label, scaler.transform(lp.features))).cache() + } else { + input.map(lp => (lp.label, lp.features)) + } } - } /** * TODO: For better convergence, in logistic regression, the intercepts should be computed diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 5ed6477bae3b2..cb70852e3cc8d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -23,10 +23,13 @@ import java.util.Arrays.binarySearch import scala.collection.mutable.ArrayBuffer +import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} import org.apache.spark.rdd.RDD /** + * :: Experimental :: + * * Regression model for isotonic regression. * * @param boundaries Array of boundaries for which predictions are known. @@ -35,6 +38,7 @@ import org.apache.spark.rdd.RDD * Results of isotonic regression and therefore monotone. * @param isotonic indicates whether this is isotonic or antitonic. */ +@Experimental class IsotonicRegressionModel ( val boundaries: Array[Double], val predictions: Array[Double], @@ -123,6 +127,8 @@ class IsotonicRegressionModel ( } /** + * :: Experimental :: + * * Isotonic regression. * Currently implemented using parallelized pool adjacent violators algorithm. * Only univariate (single feature) algorithm supported. @@ -130,14 +136,17 @@ class IsotonicRegressionModel ( * Sequential PAV implementation based on: * Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. * "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. - * Available from http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf + * Available from [[http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf]] * * Sequential PAV parallelization based on: * Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. * "An approach to parallelizing isotonic regression." * Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. - * Available from http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf + * Available from [[http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf]] + * + * @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] */ +@Experimental class IsotonicRegression private (private var isotonic: Boolean) extends Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 8ecd5c6ad93c0..e8b03816573cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -17,9 +17,11 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD /** @@ -32,7 +34,7 @@ class LassoModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -40,12 +42,37 @@ class LassoModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object LassoModel extends Loader[LassoModel] { + + override def load(sc: SparkContext, path: String): LassoModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new LassoModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LassoModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a regression model with L1-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 + * f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 81b6598377ff5..6fa7ad52a5b33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -17,9 +17,12 @@ package org.apache.spark.mllib.regression -import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.rdd.RDD /** * Regression model trained using LinearRegression. @@ -30,7 +33,8 @@ import org.apache.spark.mllib.optimization._ class LinearRegressionModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable + with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -38,12 +42,37 @@ class LinearRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object LinearRegressionModel extends Loader[LinearRegressionModel] { + + override def load(sc: SparkContext, path: String): LinearRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new LinearRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LinearRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a linear regression model with no regularization using Stochastic Gradient Descent. * This solves the least squares regression formulation - * f(weights) = 1/n ||A weights-y||^2 + * f(weights) = 1/n ||A weights-y||^2^ * (which is the mean squared error). * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 64b02f7a6e7a9..214ac4d0ed7dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -17,10 +17,12 @@ package org.apache.spark.mllib.regression +import org.json4s.{DefaultFormats, JValue} + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD @Experimental trait RegressionModel extends Serializable { @@ -48,3 +50,15 @@ trait RegressionModel extends Serializable { def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object RegressionModel { + + /** + * Helper method for loading GLM regression model metadata. + * @return numFeatures + */ + def getNumFeatures(metadata: JValue): Int = { + implicit val formats = DefaultFormats + (metadata \ "numFeatures").extract[Int] + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 076ba35051c9d..8838ca8c14718 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -17,10 +17,13 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.optimization._ +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Loader, Saveable} +import org.apache.spark.rdd.RDD + /** * Regression model trained using RidgeRegression. @@ -32,7 +35,7 @@ class RidgeRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -40,12 +43,37 @@ class RidgeRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object RidgeRegressionModel extends Loader[RidgeRegressionModel] { + + override def load(sc: SparkContext, path: String): RidgeRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new RidgeRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"RidgeRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a regression model with L2-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 + * f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index 44a8dbb994cfb..cea8f3f47307b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -21,7 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream /** @@ -39,14 +41,14 @@ import org.apache.spark.streaming.dstream.DStream * * For example usage, see `StreamingLinearRegressionWithSGD`. * - * NOTE(Freeman): In some use cases, the order in which trainOn and predictOn + * NOTE: In some use cases, the order in which trainOn and predictOn * are called in an application will affect the results. When called on * the same DStream, if trainOn is called before predictOn, when new data * arrive the model will update and the prediction will be based on the new * model. Whereas if predictOn is called first, the prediction will use the model * from the previous update. * - * NOTE(Freeman): It is ok to call predictOn repeatedly on multiple streams; this + * NOTE: It is ok to call predictOn repeatedly on multiple streams; this * will generate predictions for each one all using the current model. * It is also ok to call trainOn on different streams; this will update * the model using each of the different sources, in sequence. @@ -58,7 +60,7 @@ abstract class StreamingLinearAlgorithm[ A <: GeneralizedLinearAlgorithm[M]] extends Logging { /** The model to be updated and used for prediction. */ - protected var model: Option[M] = None + protected var model: Option[M] /** The algorithm to use for updating. */ protected val algorithm: A @@ -76,7 +78,7 @@ abstract class StreamingLinearAlgorithm[ * * @param data DStream containing labeled data */ - def trainOn(data: DStream[LabeledPoint]) { + def trainOn(data: DStream[LabeledPoint]): Unit = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting training.") } @@ -99,6 +101,9 @@ abstract class StreamingLinearAlgorithm[ } } + /** Java-friendly version of `trainOn`. */ + def trainOn(data: JavaDStream[LabeledPoint]): Unit = trainOn(data.dstream) + /** * Use the model to make predictions on batches of data from a DStream * @@ -109,7 +114,12 @@ abstract class StreamingLinearAlgorithm[ if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction.") } - data.map(model.get.predict) + data.map{x => model.get.predict(x)} + } + + /** Java-friendly version of `predictOn`. */ + def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = { + JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Double]]) } /** @@ -122,6 +132,14 @@ abstract class StreamingLinearAlgorithm[ if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction") } - data.mapValues(model.get.predict) + data.mapValues{x => model.get.predict(x)} + } + + + /** Java-friendly version of `predictOnValues`. */ + def predictOnValues[K](data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Double] = { + implicit val tag = fakeClassTag[K] + JavaPairDStream.fromPairDStream( + predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Double)]]) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index e5e6301127a28..a49153bf73c0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -59,6 +59,8 @@ class StreamingLinearRegressionWithSGD private[mllib] ( val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + protected var model: Option[LinearRegressionModel] = None + /** Set the step size for gradient descent. Default: 0.1. */ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala new file mode 100644 index 0000000000000..bd7e340ca2d8e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression.impl + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +/** + * Helper methods for import/export of GLM regression models. + */ +private[regression] object GLMRegressionModel { + + object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + /** Model data for model import/export */ + case class Data(weights: Vector, intercept: Double) + + /** + * Helper method for saving GLM regression model metadata and data. + * @param modelClass String name for model class, to be saved with metadata + */ + def save( + sc: SparkContext, + path: String, + modelClass: String, + weights: Vector, + intercept: Double): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> weights.size))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val data = Data(weights, intercept) + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + // TODO: repartition with 1 partition after SPARK-5532 gets fixed + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + /** + * Helper method for loading GLM regression model data. + * @param modelClass String name for model class (used for error messages) + * @param numFeatures Number of features, to be checked against loaded data. + * The length of the weights vector should equal numFeatures. + */ + def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(datapath) + val dataArray = dataRDD.select("weights", "intercept").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + val data = dataArray(0) + assert(data.size == 2, s"Unable to load $modelClass data from: $datapath") + data match { + case Row(weights: Vector, intercept: Double) => + assert(weights.size == numFeatures, s"Expected $numFeatures features, but" + + s" found ${weights.size} features when loading $modelClass weights from $datapath") + Data(weights, intercept) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 3cf4e807b4cf7..b3fad0c52d655 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -26,36 +26,32 @@ import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult} import org.apache.spark.rdd.RDD /** + * :: Experimental :: * API for statistical functions in MLlib. */ @Experimental object Statistics { /** - * :: Experimental :: * Computes column-wise summary statistics for the input RDD[Vector]. * * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. */ - @Experimental def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { new RowMatrix(X).computeColumnSummaryStatistics() } /** - * :: Experimental :: * Compute the Pearson correlation matrix for the input RDD of Vectors. * Columns with 0 covariance produce NaN entries in the correlation matrix. * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. */ - @Experimental def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) /** - * :: Experimental :: * Compute the correlation matrix for the input RDD of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * @@ -69,11 +65,9 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. */ - @Experimental def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) /** - * :: Experimental :: * Compute the Pearson correlation for the input RDDs. * Returns NaN if either vector has 0 variance. * @@ -84,11 +78,9 @@ object Statistics { * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s */ - @Experimental def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** - * :: Experimental :: * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * @@ -99,14 +91,12 @@ object Statistics { * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` - *@return A Double containing the correlation between the two input RDD[Double]s using the + * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. */ - @Experimental def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) /** - * :: Experimental :: * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. * @@ -120,13 +110,11 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { ChiSqTest.chiSquared(observed, expected) } /** - * :: Experimental :: * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform * distribution, with each category having an expected frequency of `1 / observed.size`. * @@ -136,11 +124,9 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) /** - * :: Experimental :: * Conduct Pearson's independence test on the input contingency matrix, which cannot contain * negative entries or columns or rows that sum up to 0. * @@ -148,11 +134,9 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) /** - * :: Experimental :: * Conduct Pearson's independence test for every feature against the label across the input RDD. * For each feature, the (feature, label) pairs are converted into a contingency matrix for which * the chi-squared statistic is computed. All label and feature values must be categorical. @@ -162,7 +146,6 @@ object Statistics { * @return an array containing the ChiSquaredTestResult for every feature against the label. * The order of the elements in the returned array reflects the order of input features. */ - @Experimental def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index fd186b5ee6f72..cd6add9d60b0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.stat.distribution -import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym} +import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV} import org.apache.spark.annotation.DeveloperApi; import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} @@ -62,21 +62,21 @@ class MultivariateGaussian ( /** Returns density of this multivariate Gaussian at given point, x */ def pdf(x: Vector): Double = { - pdf(x.toBreeze.toDenseVector) + pdf(x.toBreeze) } /** Returns the log-density of this multivariate Gaussian at given point, x */ def logpdf(x: Vector): Double = { - logpdf(x.toBreeze.toDenseVector) + logpdf(x.toBreeze) } /** Returns density of this multivariate Gaussian at given point, x */ - private[mllib] def pdf(x: DBV[Double]): Double = { + private[mllib] def pdf(x: BV[Double]): Double = { math.exp(logpdf(x)) } /** Returns the log-density of this multivariate Gaussian at given point, x */ - private[mllib] def logpdf(x: DBV[Double]): Double = { + private[mllib] def logpdf(x: BV[Double]): Double = { val delta = x - breezeMu val v = rootSigmaInv * delta u + v.t * v * -0.5 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b3e8ed9af8c51..f1f85994e61b2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - +import scala.collection.mutable.ArrayBuilder +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.configuration.Strategy @@ -32,13 +31,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.impl._ -import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.SparkContext._ - /** * :: Experimental :: @@ -1140,7 +1136,7 @@ object DecisionTree extends Serializable with Logging { logDebug("stride = " + stride) // iterate `valueCount` to find splits - val splits = new ArrayBuffer[Double] + val splitsBuilder = ArrayBuilder.make[Double] var index = 1 // currentCount: sum of counts of values that have been visited var currentCount = valueCounts(0)._2 @@ -1158,13 +1154,13 @@ object DecisionTree extends Serializable with Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. if (previousGap < currentGap) { - splits.append(valueCounts(index - 1)._1) + splitsBuilder += valueCounts(index - 1)._1 targetCount += stride } index += 1 } - splits.toArray + splitsBuilder.result() } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 482dd4b272d1d..db01f2e229e5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import java.io.IOException + import scala.collection.mutable import scala.collection.JavaConverters._ @@ -202,7 +204,6 @@ private class RandomForest ( Some(NodeIdCache.init( data = baggedInput, numTrees = numTrees, - checkpointDir = strategy.checkpointDir, checkpointInterval = strategy.checkpointInterval, initVal = 1)) } else { @@ -244,7 +245,12 @@ private class RandomForest ( // Delete any remaining checkpoints used for node Id cache. if (nodeIdCache.nonEmpty) { - nodeIdCache.get.deleteAllCheckpoints() + try { + nodeIdCache.get.deleteAllCheckpoints() + } catch { + case e:IOException => + logWarning(s"delete all chackpoints failed. Error reason: ${e.getMessage}") + } } val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 3308adb6752ff..8d5c36da32bdb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -62,11 +62,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * @param subsamplingRate Fraction of the training data used for learning decision tree. * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will * maintain a separate RDD of node Id cache for each row. - * @param checkpointDir If the node Id cache is used, it will help to checkpoint - * the node Id cache periodically. This is the checkpoint directory - * to be used for the node Id cache. * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. - * E.g. 10 means that the cache will get checkpointed every 10 updates. + * E.g. 10 means that the cache will get checkpointed every 10 updates. If + * the checkpoint directory is not set in + * [[org.apache.spark.SparkContext]], this setting is ignored. */ @Experimental class Strategy ( @@ -82,7 +81,6 @@ class Strategy ( @BeanProperty var maxMemoryInMB: Int = 256, @BeanProperty var subsamplingRate: Double = 1, @BeanProperty var useNodeIdCache: Boolean = false, - @BeanProperty var checkpointDir: Option[String] = None, @BeanProperty var checkpointInterval: Int = 10) extends Serializable { def isMulticlassClassification = @@ -165,7 +163,7 @@ class Strategy ( def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, - maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval) + maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 83011b48b7d9b..bdd0f576b048d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -71,15 +71,12 @@ private[tree] case class NodeIndexUpdater( * The nodeIdsForInstances RDD needs to be updated at each iteration. * @param nodeIdsForInstances The initial values in the cache * (should be an Array of all 1's (meaning the root nodes)). - * @param checkpointDir The checkpoint directory where - * the checkpointed files will be stored. * @param checkpointInterval The checkpointing interval * (how often should the cache be checkpointed.). */ @DeveloperApi private[tree] class NodeIdCache( var nodeIdsForInstances: RDD[Array[Int]], - val checkpointDir: Option[String], val checkpointInterval: Int) { // Keep a reference to a previous node Ids for instances. @@ -91,12 +88,6 @@ private[tree] class NodeIdCache( private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() private var rddUpdateCount = 0 - // If a checkpoint directory is given, and there's no prior checkpoint directory, - // then set the checkpoint directory with the given one. - if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) { - nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get) - } - /** * Update the node index values in the cache. * This updates the RDD and its lineage. @@ -184,7 +175,6 @@ private[tree] object NodeIdCache { * Initialize the node Id cache with initial node Id values. * @param data The RDD of training rows. * @param numTrees The number of trees that we want to create cache for. - * @param checkpointDir The checkpoint directory where the checkpointed files will be stored. * @param checkpointInterval The checkpointing interval * (how often should the cache be checkpointed.). * @param initVal The initial values in the cache. @@ -193,12 +183,10 @@ private[tree] object NodeIdCache { def init( data: RDD[BaggedPoint[TreePoint]], numTrees: Int, - checkpointDir: Option[String], checkpointInterval: Int, initVal: Int = 1): NodeIdCache = { new NodeIdCache( data.map(_ => Array.fill[Int](numTrees)(initVal)), - checkpointDir, checkpointInterval) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a5760963068c3..8a57ebc387d01 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,11 +17,22 @@ package org.apache.spark.mllib.tree.model +import scala.collection.mutable + +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType} import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -31,7 +42,7 @@ import org.apache.spark.rdd.RDD * @param algo algorithm type -- classification or regression */ @Experimental -class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { +class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable { /** * Predict values for a single data point using the model trained. @@ -53,7 +64,6 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable features.map(x => predict(x)) } - /** * Predict values for the given data set using the model trained. * @@ -99,4 +109,205 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable header + topNode.subtreeToString(2) } + override def save(sc: SparkContext, path: String): Unit = { + DecisionTreeModel.SaveLoadV1_0.save(sc, path, this) + } + + override protected def formatVersion: String = "1.0" +} + +object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { + + private[tree] object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + // Hard-code class name string in case it changes in the future + def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel" + + case class PredictData(predict: Double, prob: Double) { + def toPredict: Predict = new Predict(predict, prob) + } + + object PredictData { + def apply(p: Predict): PredictData = PredictData(p.predict, p.prob) + + def apply(r: Row): PredictData = PredictData(r.getDouble(0), r.getDouble(1)) + } + + case class SplitData( + feature: Int, + threshold: Double, + featureType: Int, + categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed + def toSplit: Split = { + new Split(feature, threshold, FeatureType(featureType), categories.toList) + } + } + + object SplitData { + def apply(s: Split): SplitData = { + SplitData(s.feature, s.threshold, s.featureType.id, s.categories) + } + + def apply(r: Row): SplitData = { + SplitData(r.getInt(0), r.getDouble(1), r.getInt(2), r.getAs[Seq[Double]](3)) + } + } + + /** Model data for model import/export */ + case class NodeData( + treeId: Int, + nodeId: Int, + predict: PredictData, + impurity: Double, + isLeaf: Boolean, + split: Option[SplitData], + leftNodeId: Option[Int], + rightNodeId: Option[Int], + infoGain: Option[Double]) + + object NodeData { + def apply(treeId: Int, n: Node): NodeData = { + NodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf, + n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id), + n.stats.map(_.gain)) + } + + def apply(r: Row): NodeData = { + val split = if (r.isNullAt(5)) None else Some(SplitData(r.getStruct(5))) + val leftNodeId = if (r.isNullAt(6)) None else Some(r.getInt(6)) + val rightNodeId = if (r.isNullAt(7)) None else Some(r.getInt(7)) + val infoGain = if (r.isNullAt(8)) None else Some(r.getDouble(8)) + NodeData(r.getInt(0), r.getInt(1), PredictData(r.getStruct(2)), r.getDouble(3), + r.getBoolean(4), split, leftNodeId, rightNodeId, infoGain) + } + } + + def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // SPARK-6120: We do a hacky check here so users understand why save() is failing + // when they run the ML guide example. + // TODO: Fix this issue for real. + val memThreshold = 768 + if (sc.isLocal) { + val driverMemory = sc.getConf.getOption("spark.driver.memory") + .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) + if (driverMemory <= memThreshold) { + logWarning(s"$thisClassName.save() was called, but it may fail because of too little" + + s" driver memory (${driverMemory}m)." + + s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).") + } + } else { + if (sc.executorMemory <= memThreshold) { + logWarning(s"$thisClassName.save() was called, but it may fail because of too little" + + s" executor memory (${sc.executorMemory}m)." + + s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).") + } + } + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("algo" -> model.algo.toString) ~ ("numNodes" -> model.numNodes))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val nodes = model.topNode.subtreeIterator.toSeq + val dataRDD: DataFrame = sc.parallelize(nodes) + .map(NodeData.apply(0, _)) + .toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(datapath) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[NodeData](dataRDD.schema) + val nodes = dataRDD.map(NodeData.apply) + // Build node data into a tree. + val trees = constructTrees(nodes) + assert(trees.size == 1, + "Decision tree should contain exactly one tree but got ${trees.size} trees.") + val model = new DecisionTreeModel(trees(0), Algo.fromString(algo)) + assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath." + + s" Expected $numNodes nodes but found ${model.numNodes}") + model + } + + def constructTrees(nodes: RDD[NodeData]): Array[Node] = { + val trees = nodes + .groupBy(_.treeId) + .mapValues(_.toArray) + .collect() + .map { case (treeId, data) => + (treeId, constructTree(data)) + }.sortBy(_._1) + val numTrees = trees.size + val treeIndices = trees.map(_._1).toSeq + assert(treeIndices == (0 until numTrees), + s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.") + trees.map(_._2) + } + + /** + * Given a list of nodes from a tree, construct the tree. + * @param data array of all node data in a tree. + */ + def constructTree(data: Array[NodeData]): Node = { + val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap + assert(dataMap.contains(1), + s"DecisionTree missing root node (id = 1).") + constructNode(1, dataMap, mutable.Map.empty) + } + + /** + * Builds a node from the node data map and adds new nodes to the input nodes map. + */ + private def constructNode( + id: Int, + dataMap: Map[Int, NodeData], + nodes: mutable.Map[Int, Node]): Node = { + if (nodes.contains(id)) { + return nodes(id) + } + val data = dataMap(id) + val node = + if (data.isLeaf) { + Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf) + } else { + val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes) + val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes) + val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity, + rightNode.impurity, leftNode.predict, rightNode.predict) + new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf, + data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats)) + } + nodes += node.id -> node + node + } + } + + override def load(sc: SparkContext, path: String): DecisionTreeModel = { + implicit val formats = DefaultFormats + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val algo = (metadata \ "algo").extract[String] + val numNodes = (metadata \ "numNodes").extract[Int] + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path, algo, numNodes) + case _ => throw new Exception( + s"DecisionTreeModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 9a50ecb550c38..80990aa9a603f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -49,7 +49,9 @@ class InformationGainStats( gain == other.gain && impurity == other.impurity && leftImpurity == other.leftImpurity && - rightImpurity == other.rightImpurity + rightImpurity == other.rightImpurity && + leftPredict == other.leftPredict && + rightPredict == other.rightPredict } case _ => false } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 2179da8dbe03e..d961081d185e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -166,6 +166,11 @@ class Node ( } } + /** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */ + private[tree] def subtreeIterator: Iterator[Node] = { + Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++ + rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty) + } } private[tree] object Node { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 004838ee5ba0e..ad4c0dbbfb3e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -32,4 +32,11 @@ class Predict( override def toString = { "predict = %f, prob = %f".format(predict, prob) } + + override def equals(other: Any): Boolean = { + other match { + case p: Predict => predict == p.predict && prob == p.prob + case _ => false + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 22997110de8dd..30a8f7ca301af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -20,13 +20,21 @@ package org.apache.spark.mllib.tree.model import scala.collection.mutable import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.Algo import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -38,9 +46,42 @@ import org.apache.spark.rdd.RDD @Experimental class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0), - combiningStrategy = if (algo == Classification) Vote else Average) { + combiningStrategy = if (algo == Classification) Vote else Average) + with Saveable { require(trees.forall(_.algo == algo)) + + override def save(sc: SparkContext, path: String): Unit = { + TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, + RandomForestModel.SaveLoadV1_0.thisClassName) + } + + override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion +} + +object RandomForestModel extends Loader[RandomForestModel] { + + override def load(sc: SparkContext, path: String): RandomForestModel = { + val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata) + assert(metadata.treeWeights.forall(_ == 1.0)) + val trees = + TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo) + new RandomForestModel(Algo.fromString(metadata.algo), trees) + case _ => throw new Exception(s"RandomForestModel.load did not recognize model" + + s" with (className, format version): ($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private object SaveLoadV1_0 { + // Hard-code class name string in case it changes in the future + def thisClassName = "org.apache.spark.mllib.tree.model.RandomForestModel" + } + } /** @@ -56,9 +97,42 @@ class GradientBoostedTreesModel( override val algo: Algo, override val trees: Array[DecisionTreeModel], override val treeWeights: Array[Double]) - extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) { + extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) + with Saveable { require(trees.size == treeWeights.size) + + override def save(sc: SparkContext, path: String): Unit = { + TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, + GradientBoostedTreesModel.SaveLoadV1_0.thisClassName) + } + + override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion +} + +object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { + + override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { + val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata) + assert(metadata.combiningStrategy == Sum.toString) + val trees = + TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo) + new GradientBoostedTreesModel(Algo.fromString(metadata.algo), trees, metadata.treeWeights) + case _ => throw new Exception(s"GradientBoostedTreesModel.load did not recognize model" + + s" with (className, format version): ($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private object SaveLoadV1_0 { + // Hard-code class name string in case it changes in the future + def thisClassName = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel" + } + } /** @@ -176,3 +250,96 @@ private[tree] sealed class TreeEnsembleModel( */ def totalNumNodes: Int = trees.map(_.numNodes).sum } + +private[tree] object TreeEnsembleModel extends Logging { + + object SaveLoadV1_0 { + + import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees} + + def thisFormatVersion = "1.0" + + case class Metadata( + algo: String, + treeAlgo: String, + combiningStrategy: String, + treeWeights: Array[Double]) + + /** + * Model data for model import/export. + * We have to duplicate NodeData here since Spark SQL does not yet support extracting subfields + * of nested fields; once that is possible, we can use something like: + * case class EnsembleNodeData(treeId: Int, node: NodeData), + * where NodeData is from DecisionTreeModel. + */ + case class EnsembleNodeData(treeId: Int, node: NodeData) + + def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // SPARK-6120: We do a hacky check here so users understand why save() is failing + // when they run the ML guide example. + // TODO: Fix this issue for real. + val memThreshold = 768 + if (sc.isLocal) { + val driverMemory = sc.getConf.getOption("spark.driver.memory") + .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) + if (driverMemory <= memThreshold) { + logWarning(s"$className.save() was called, but it may fail because of too little" + + s" driver memory (${driverMemory}m)." + + s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).") + } + } else { + if (sc.executorMemory <= memThreshold) { + logWarning(s"$className.save() was called, but it may fail because of too little" + + s" executor memory (${sc.executorMemory}m)." + + s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).") + } + } + + // Create JSON metadata. + implicit val format = DefaultFormats + val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString, + model.combiningStrategy.toString, model.treeWeights) + val metadata = compact(render( + ("class" -> className) ~ ("version" -> thisFormatVersion) ~ + ("metadata" -> Extraction.decompose(ensembleMetadata)))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => + tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node)) + }.toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + /** + * Read metadata from the loaded JSON metadata. + */ + def readMetadata(metadata: JValue): Metadata = { + implicit val formats = DefaultFormats + (metadata \ "metadata").extract[Metadata] + } + + /** + * Load trees for an ensemble, and return them in order. + * @param path path to load the model from + * @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's + * algorithm). + */ + def loadTrees( + sc: SparkContext, + path: String, + treeAlgo: String): Array[DecisionTreeModel] = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val nodes = sqlContext.parquetFile(datapath).map(NodeData.apply) + val trees = constructTrees(nodes) + trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala index f7cba6c6cb628..308f7f3578e21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.util import java.util.StringTokenizer -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.{ArrayBuilder, ListBuffer} import org.apache.spark.SparkException @@ -51,7 +51,7 @@ private[mllib] object NumericParser { } private def parseArray(tokenizer: StringTokenizer): Array[Double] = { - val values = ArrayBuffer.empty[Double] + val values = ArrayBuilder.make[Double] var parsing = true var allowComma = false var token: String = null @@ -67,14 +67,14 @@ private[mllib] object NumericParser { } } else { // expecting a number - values.append(parseDouble(token)) + values += parseDouble(token) allowComma = true } } if (parsing) { throw new SparkException(s"An array must end with ']'.") } - values.toArray + values.result() } private def parseTuple(tokenizer: StringTokenizer): Seq[_] = { @@ -114,7 +114,7 @@ private[mllib] object NumericParser { try { java.lang.Double.parseDouble(s) } catch { - case e: Throwable => + case e: NumberFormatException => throw new SparkException(s"Cannot parse a double from: $s", e) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala new file mode 100644 index 0000000000000..30d642c754b7c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * :: DeveloperApi :: + * + * Trait for models and transformers which may be saved as files. + * This should be inherited by the class which implements model instances. + */ +@DeveloperApi +trait Saveable { + + /** + * Save this model to the given path. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[Loader.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + def save(sc: SparkContext, path: String): Unit + + /** Current version of model save/load format. */ + protected def formatVersion: String + +} + +/** + * :: DeveloperApi :: + * + * Trait for classes which can load models and transformers from files. + * This should be inherited by an object paired with the model class. + */ +@DeveloperApi +trait Loader[M <: Saveable] { + + /** + * Load a model from the given path. + * + * The model should have been saved by [[Saveable.save]]. + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + def load(sc: SparkContext, path: String): M + +} + +/** + * Helper methods for loading models from files. + */ +private[mllib] object Loader { + + /** Returns URI for path/data using the Hadoop filesystem */ + def dataPath(path: String): String = new Path(path, "data").toUri.toString + + /** Returns URI for path/metadata using the Hadoop filesystem */ + def metadataPath(path: String): String = new Path(path, "metadata").toUri.toString + + /** + * Check the schema of loaded model data. + * + * This checks every field in the expected schema to make sure that a field with the same + * name and DataType appears in the loaded schema. Note that this does NOT check metadata + * or containsNull. + * + * @param loadedSchema Schema for model data loaded from file. + * @tparam Data Expected data type from which an expected schema can be derived. + */ + def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = { + // Check schema explicitly since erasure makes it hard to use match-case for checking. + val expectedFields: Array[StructField] = + ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields + val loadedFields: Map[String, DataType] = + loadedSchema.map(field => field.name -> field.dataType).toMap + expectedFields.foreach { field => + assert(loadedFields.contains(field.name), s"Unable to parse model data." + + s" Expected field with name ${field.name} was missing in loaded schema:" + + s" ${loadedFields.mkString(", ")}") + assert(loadedFields(field.name).sameType(field.dataType), + s"Unable to parse model data. Expected field $field but found field" + + s" with different type: ${loadedFields(field.name)}") + } + } + + /** + * Load metadata from the given path. + * @return (class name, version, metadata) + */ + def loadMetadata(sc: SparkContext, path: String): (String, String, JValue) = { + implicit val formats = DefaultFormats + val metadata = parse(sc.textFile(metadataPath(path)).first()) + val clazz = (metadata \ "class").extract[String] + val version = (metadata \ "version").extract[String] + (clazz, version, metadata) + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 56a9dbdd58b64..0a8c9e5954676 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -45,7 +45,7 @@ public void setUp() { jsql = new SQLContext(jsc); JavaRDD points = jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); - dataset = jsql.applySchema(points, LabeledPoint.class); + dataset = jsql.createDataFrame(points, LabeledPoint.class); } @After @@ -65,7 +65,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index f4ba23c44563e..3f8e59de0f05c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -18,17 +18,22 @@ package org.apache.spark.ml.classification; import java.io.Serializable; +import java.lang.Math; import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import org.apache.spark.sql.Row; + public class JavaLogisticRegressionSuite implements Serializable { @@ -36,12 +41,17 @@ public class JavaLogisticRegressionSuite implements Serializable { private transient SQLContext jsql; private transient DataFrame dataset; + private transient JavaRDD datasetRDD; + private double eps = 1e-5; + @Before public void setUp() { jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + datasetRDD = jsc.parallelize(points, 2); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset.registerTempTable("dataset"); } @After @@ -51,29 +61,88 @@ public void tearDown() { } @Test - public void logisticRegression() { + public void logisticRegressionDefaultParams() { LogisticRegression lr = new LogisticRegression(); + assert(lr.getLabelCol().equals("label")); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); + // Check defaults + assert(model.getThreshold() == 0.5); + assert(model.getFeaturesCol().equals("features")); + assert(model.getPredictionCol().equals("prediction")); + assert(model.getProbabilityCol().equals("probability")); } @Test public void logisticRegressionWithSetters() { + // Set params, train, and check as many params as we can. LogisticRegression lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(1.0); + .setRegParam(1.0) + .setThreshold(0.6) + .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); - model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold - .registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); - predictions.collectAsList(); + assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); + assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); + assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6)); + assert(model.getThreshold() == 0.6); + + // Modify model params, and check that the params worked. + model.setThreshold(1.0); + model.transform(dataset).registerTempTable("predAllZero"); + DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); + for (Row r: predAllZero.collectAsList()) { + assert(r.getDouble(0) == 0.0); + } + // Call transform with params, and check that the params worked. + model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) + .registerTempTable("predNotAllZero"); + DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); + boolean foundNonZero = false; + for (Row r: predNotAllZero.collectAsList()) { + if (r.getDouble(0) != 0.0) foundNonZero = true; + } + assert(foundNonZero); + + // Call fit() with new params, and check as many params as we can. + LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); + assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); + assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); + assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4)); + assert(model2.getThreshold() == 0.4); + assert(model2.getProbabilityCol().equals("theProb")); } + @SuppressWarnings("unchecked") @Test - public void logisticRegressionFitWithVarargs() { + public void logisticRegressionPredictorClassifierMethods() { LogisticRegression lr = new LogisticRegression(); - lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0)); + LogisticRegressionModel model = lr.fit(dataset); + assert(model.numClasses() == 2); + + model.transform(dataset).registerTempTable("transformed"); + DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row: trans1.collect()) { + Vector raw = (Vector)row.get(0); + Vector prob = (Vector)row.get(1); + assert(raw.size() == 2); + assert(prob.size() == 2); + double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); + assert(Math.abs(prob.apply(1) - probFromRaw1) < eps); + assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps); + } + + DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); + for (Row row: trans2.collect()) { + double pred = row.getDouble(0); + Vector prob = (Vector)row.get(1); + double probOfPred = prob.apply((int)pred); + for (int i = 0; i < prob.size(); ++i) { + assert(probOfPred >= prob.apply(i)); + } + } } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java new file mode 100644 index 0000000000000..640d2ec55e4e7 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; + +public class JavaStreamingLogisticRegressionSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List trainingBatch = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(1.0)), + new LabeledPoint(0.0, Vectors.dense(0.0))); + JavaDStream training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List> testBatch = Lists.newArrayList( + new Tuple2(10, Vectors.dense(1.0)), + new Tuple2(11, Vectors.dense(0.0))); + JavaPairDStream test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD() + .setNumIterations(2) + .setInitialWeights(Vectors.dense(0.0)); + slr.trainOn(training); + JavaPairDStream prediction = slr.predictOnValues(test); + attachTestOutputStream(prediction.count()); + runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java new file mode 100644 index 0000000000000..0cc36c8d56d70 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + + +public class JavaLinearRegressionSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + private transient DataFrame dataset; + private transient JavaRDD datasetRDD; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + jsql = new SQLContext(jsc); + List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + datasetRDD = jsc.parallelize(points, 2); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset.registerTempTable("dataset"); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void linearRegressionDefaultParams() { + LinearRegression lr = new LinearRegression(); + assert(lr.getLabelCol().equals("label")); + LinearRegressionModel model = lr.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); + predictions.collect(); + // Check defaults + assert(model.getFeaturesCol().equals("features")); + assert(model.getPredictionCol().equals("prediction")); + } + + @Test + public void linearRegressionWithSetters() { + // Set params, train, and check as many params as we can. + LinearRegression lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(1.0); + LinearRegressionModel model = lr.fit(dataset); + assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); + assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); + + // Call fit() with new params, and check as many params as we can. + LinearRegressionModel model2 = + lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); + assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); + assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); + assert(model2.getPredictionCol().equals("thePred")); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 074b58c07df7a..0bb6b489f2757 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -45,7 +45,7 @@ public void setUp() { jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); } @After diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java new file mode 100644 index 0000000000000..bd0edf2b9ea62 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; +import static org.junit.Assert.*; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; + +public class JavaFPGrowthSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaFPGrowth"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runFPGrowth() { + + @SuppressWarnings("unchecked") + JavaRDD> rdd = sc.parallelize(Lists.newArrayList( + Lists.newArrayList("r z h k p".split(" ")), + Lists.newArrayList("z y x w v u t s".split(" ")), + Lists.newArrayList("s x o n r".split(" ")), + Lists.newArrayList("x z y m t s q e".split(" ")), + Lists.newArrayList("z".split(" ")), + Lists.newArrayList("x z y r q t p".split(" "))), 2); + + FPGrowthModel model = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd); + + List> freqItemsets = model.freqItemsets().toJavaRDD().collect(); + assertEquals(18, freqItemsets.size()); + + for (FreqItemset itemset: freqItemsets) { + // Test return types. + List items = itemset.javaItems(); + long freq = itemset.freq(); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java index 704d484d0b585..3349c5022423a 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -71,8 +71,8 @@ public void diagonalMatrixConstruction() { Matrix sm = Matrices.diag(sv); DenseMatrix d = DenseMatrix.diag(v); DenseMatrix sd = DenseMatrix.diag(sv); - SparseMatrix s = SparseMatrix.diag(v); - SparseMatrix ss = SparseMatrix.diag(sv); + SparseMatrix s = SparseMatrix.spdiag(v); + SparseMatrix ss = SparseMatrix.spdiag(sv); assertArrayEquals(m.toArray(), sm.toArray(), 0.0); assertArrayEquals(d.toArray(), sm.toArray(), 0.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java new file mode 100644 index 0000000000000..899c4ea607869 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; + +public class JavaStreamingLinearRegressionSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List trainingBatch = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(1.0)), + new LabeledPoint(0.0, Vectors.dense(0.0))); + JavaDStream training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List> testBatch = Lists.newArrayList( + new Tuple2(10, Vectors.dense(1.0)), + new Tuple2(11, Vectors.dense(0.0))); + JavaPairDStream test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD() + .setNumIterations(2) + .setInitialWeights(Vectors.dense(0.0)); + slr.trainOn(training); + JavaPairDStream prediction = slr.predictOnValues(test); + attachTestOutputStream(prediction.count()); + runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 33e40dc7410cc..b3d1bfcfbee0f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -20,44 +20,108 @@ package org.apache.spark.ml.classification import org.scalatest.FunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, DataFrame} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ + private val eps: Double = 1e-5 override def beforeAll(): Unit = { super.beforeAll() sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) } - test("logistic regression") { + test("logistic regression: default params") { val lr = new LogisticRegression + assert(lr.getLabelCol == "label") + assert(lr.getFeaturesCol == "features") + assert(lr.getPredictionCol == "prediction") + assert(lr.getRawPredictionCol == "rawPrediction") + assert(lr.getProbabilityCol == "probability") val model = lr.fit(dataset) model.transform(dataset) - .select("label", "prediction") + .select("label", "probability", "prediction", "rawPrediction") .collect() + assert(model.getThreshold === 0.5) + assert(model.getFeaturesCol == "features") + assert(model.getPredictionCol == "prediction") + assert(model.getRawPredictionCol == "rawPrediction") + assert(model.getProbabilityCol == "probability") } test("logistic regression with setters") { + // Set params, train, and check as many params as we can. val lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) + .setThreshold(0.6) + .setProbabilityCol("myProbability") val model = lr.fit(dataset) - model.transform(dataset, model.threshold -> 0.8) // overwrite threshold - .select("label", "score", "prediction") + assert(model.fittingParamMap.get(lr.maxIter) === Some(10)) + assert(model.fittingParamMap.get(lr.regParam) === Some(1.0)) + assert(model.fittingParamMap.get(lr.threshold) === Some(0.6)) + assert(model.getThreshold === 0.6) + + // Modify model params, and check that the params worked. + model.setThreshold(1.0) + val predAllZero = model.transform(dataset) + .select("prediction", "myProbability") .collect() + .map { case Row(pred: Double, prob: Vector) => pred } + assert(predAllZero.forall(_ === 0), + s"With threshold=1.0, expected predictions to be all 0, but only" + + s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") + // Call transform with params, and check that the params worked. + val predNotAllZero = + model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") + .select("prediction", "myProb") + .collect() + .map { case Row(pred: Double, prob: Vector) => pred } + assert(predNotAllZero.exists(_ !== 0.0)) + + // Call fit() with new params, and check as many params as we can. + val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4, + lr.probabilityCol -> "theProb") + assert(model2.fittingParamMap.get(lr.maxIter).get === 5) + assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) + assert(model2.fittingParamMap.get(lr.threshold).get === 0.4) + assert(model2.getThreshold === 0.4) + assert(model2.getProbabilityCol == "theProb") } - test("logistic regression fit and transform with varargs") { + test("logistic regression: Predictor, Classifier methods") { + val sqlContext = this.sqlContext val lr = new LogisticRegression - val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) - model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") - .select("label", "probability", "prediction") - .collect() + + val model = lr.fit(dataset) + assert(model.numClasses === 2) + + val threshold = model.getThreshold + val results = model.transform(dataset) + + // Compare rawPrediction with probability + results.select("rawPrediction", "probability").collect().map { + case Row(raw: Vector, prob: Vector) => + assert(raw.size === 2) + assert(prob.size === 2) + val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1))) + assert(prob(1) ~== probFromRaw1 relTol eps) + assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps) + } + + // Compare prediction with probability + results.select("prediction", "probability").collect().map { + case Row(pred: Double, prob: Vector) => + val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 + assert(pred == predFromProb) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index acc447742bad0..29d4ec5f85c1e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.recommendation +import java.io.File import java.util.Random import scala.collection.mutable @@ -25,23 +26,32 @@ import scala.collection.mutable.ArrayBuffer import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.scalatest.FunSuite -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkException} import org.apache.spark.ml.recommendation.ALS._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.util.Utils class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { private var sqlContext: SQLContext = _ + private var tempDir: File = _ override def beforeAll(): Unit = { super.beforeAll() + tempDir = Utils.createTempDir() + sc.setCheckpointDir(tempDir.getAbsolutePath) sqlContext = new SQLContext(sc) } + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + super.afterAll() + } + test("LocalIndexEncoder") { val random = new Random for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) { @@ -58,39 +68,42 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } } - test("normal equation construction with explict feedback") { + test("normal equation construction") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 3.0f) - .add(Array(4.0f, 5.0f), 6.0f) + .add(Array(1.0f, 2.0f), 3.0) + .add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted assert(ne0.k === k) assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 2) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5") // b = np.matrix("3; 6") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(33.0, 42.0, 54.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(51.0, 66.0) relTol 1e-8) val ne1 = new NormalEquation(2) - .add(Array(7.0f, 8.0f), 9.0f) + .add(Array(7.0f, 8.0f), 9.0) ne0.merge(ne1) - assert(ne0.n === 3) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5; 7 8") // b = np.matrix("3; 6; 9") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2, 1])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(82.0, 98.0, 118.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(114.0, 138.0) relTol 1e-8) intercept[IllegalArgumentException] { - ne0.add(Array(1.0f), 2.0f) + ne0.add(Array(1.0f), 2.0) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0) } intercept[IllegalArgumentException] { - ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f) + ne0.add(Array(1.0f, 2.0f), 0.0, -1.0) } intercept[IllegalArgumentException] { val ne2 = new NormalEquation(3) @@ -98,41 +111,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } ne0.reset() - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) } - test("normal equation construction with implicit feedback") { - val k = 2 - val alpha = 0.5 - val ne0 = new NormalEquation(k) - .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) - .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) - .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) - assert(ne0.k === k) - assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 0) // addImplicit doesn't increase the count. - // NumPy code that computes the expected values: - // alpha = 0.5 - // A = np.matrix("-5 -4; -2 -1; 1 2") - // b = np.matrix("-3; 0; 3") - // b1 = b > 0 - // c = 1.0 + alpha * np.abs(b) - // C = np.diag(c.A1) - // I = np.eye(3) - // ata = A.transpose() * (C - I) * A - // atb = A.transpose() * C * b1 - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) - } - test("CholeskySolver") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 4.0f) - .add(Array(1.0f, 3.0f), 9.0f) - .add(Array(1.0f, 4.0f), 16.0f) + .add(Array(1.0f, 2.0f), 4.0) + .add(Array(1.0f, 3.0f), 9.0) + .add(Array(1.0f, 4.0f), 16.0) val ne1 = new NormalEquation(k) .merge(ne0) @@ -144,13 +132,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { // x0 = np.linalg.lstsq(A, b)[0] assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) - val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + val x1 = chol.solve(ne1, 1.5).map(_.toDouble) // NumPy code that computes the expected solution, where lambda is scaled by n: - // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + // x0 = np.linalg.solve(A.transpose() * A + 1.5 * np.eye(2), A.transpose() * b) assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) } @@ -350,7 +337,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { numItemBlocks: Int = 3, targetRMSE: Double = 0.05): Unit = { val sqlContext = this.sqlContext - import sqlContext.createDataFrame + import sqlContext.implicits._ val als = new ALS() .setRank(rank) .setRegParam(regParam) @@ -358,8 +345,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { .setNumUserBlocks(numUserBlocks) .setNumItemBlocks(numItemBlocks) val alpha = als.getAlpha - val model = als.fit(training) - val predictions = model.transform(test) + val model = als.fit(training.toDF()) + val predictions = model.transform(test.toDF()) .select("rating", "prediction") .map { case Row(rating: Float, prediction: Float) => (rating.toDouble, prediction.toDouble) @@ -455,4 +442,41 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { assert(isNonnegative(itemFactors)) // TODO: Validate the solution. } + + test("als partitioner is a projection") { + for (p <- Seq(1, 10, 100, 1000)) { + val part = new ALSPartitioner(p) + var k = 0 + while (k < p) { + assert(k === part.getPartition(k)) + assert(k === part.getPartition(k.toLong)) + k += 1 + } + } + } + + test("partitioner in returned factors") { + val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + val (userFactors, itemFactors) = ALS.train( + ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4) + for ((tpe, factors) <- Seq(("User", userFactors), ("Item", itemFactors))) { + assert(userFactors.partitioner.isDefined, s"$tpe factors should have partitioner.") + val part = userFactors.partitioner.get + userFactors.mapPartitionsWithIndex { (idx, items) => + items.foreach { case (id, _) => + if (part.getPartition(id) != idx) { + throw new SparkException(s"$tpe with ID $id should not be in partition $idx.") + } + } + Iterator.empty + }.count() + } + } + + test("als with large number of iterations") { + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2) + ALS.train( + ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala new file mode 100644 index 0000000000000..bbb44c3e2dfc2 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, SQLContext} + +class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + dataset = sqlContext.createDataFrame( + sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) + } + + test("linear regression: default params") { + val lr = new LinearRegression + assert(lr.getLabelCol == "label") + val model = lr.fit(dataset) + model.transform(dataset) + .select("label", "prediction") + .collect() + // Check defaults + assert(model.getFeaturesCol == "features") + assert(model.getPredictionCol == "prediction") + } + + test("linear regression with setters") { + // Set params, train, and check as many as we can. + val lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(1.0) + val model = lr.fit(dataset) + assert(model.fittingParamMap.get(lr.maxIter).get === 10) + assert(model.fittingParamMap.get(lr.regParam).get === 1.0) + + // Call fit() with new params, and check as many as we can. + val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred") + assert(model2.fittingParamMap.get(lr.maxIter).get === 5) + assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) + assert(model2.getPredictionCol == "thePred") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 3fb45938f75db..a26c52852c4d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.mllib.classification -import scala.util.control.Breaks._ -import scala.util.Random import scala.collection.JavaConversions._ +import scala.util.Random +import scala.util.control.Breaks._ import org.scalatest.FunSuite import org.scalatest.Matchers @@ -28,6 +28,8 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + object LogisticRegressionSuite { @@ -147,8 +149,25 @@ object LogisticRegressionSuite { val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i))) testData } + + /** Binary labels, 3 features */ + private val binaryModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5, numFeatures = 3, numClasses = 2) + + /** 3 classes, 2 features */ + private val multiclassModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) + + private def checkModelsEqual(a: LogisticRegressionModel, b: LogisticRegressionModel): Unit = { + assert(a.weights == b.weights) + assert(a.intercept == b.intercept) + assert(a.numClasses == b.numClasses) + assert(a.numFeatures == b.numFeatures) + assert(a.getThreshold == b.getThreshold) + } } + class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], @@ -353,8 +372,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M testRDD2.cache() testRDD3.cache() + val numIteration = 10 + val lrA = new LogisticRegressionWithLBFGS().setIntercept(true) + lrA.optimizer.setNumIterations(numIteration) val lrB = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + lrB.optimizer.setNumIterations(numIteration) val modelA1 = lrA.run(testRDD1, initialWeights) val modelA2 = lrA.run(testRDD2, initialWeights) @@ -402,6 +425,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val model = lr.run(testRDD) + val numFeatures = testRDD.map(_.features.size).first() + val initialWeights = Vectors.dense(new Array[Double]((numFeatures + 1) * 2)) + val model2 = lr.run(testRDD, initialWeights) + + LogisticRegressionSuite.checkModelsEqual(model, model2) + /** * The following is the instruction to reproduce the model using R's glmnet package. * @@ -462,6 +491,53 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M } + test("model save/load: binary classification") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = LogisticRegressionSuite.binaryModel + + model.clearThreshold() + assert(model.getThreshold.isEmpty) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + + // Save model with threshold. + try { + model.setThreshold(0.7) + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("model save/load: multiclass classification") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = LogisticRegressionSuite.multiclassModel + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index e68fe89d6ccea..64dcc0fb9f82c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -25,6 +25,8 @@ import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.Utils + object NaiveBayesSuite { @@ -58,6 +60,18 @@ object NaiveBayesSuite { LabeledPoint(y, Vectors.dense(xi)) } } + + private val smallPi = Array(0.5, 0.3, 0.2).map(math.log) + + private val smallTheta = Array( + Array(0.91, 0.03, 0.03, 0.03), // label 0 + Array(0.03, 0.91, 0.03, 0.03), // label 1 + Array(0.03, 0.03, 0.91, 0.03) // label 2 + ).map(_.map(math.log)) + + /** Binary labels, 3 features */ + private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8), + theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4))) } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { @@ -74,12 +88,8 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { test("Naive Bayes") { val nPoints = 10000 - val pi = Array(0.5, 0.3, 0.2).map(math.log) - val theta = Array( - Array(0.91, 0.03, 0.03, 0.03), // label 0 - Array(0.03, 0.91, 0.03, 0.03), // label 1 - Array(0.03, 0.03, 0.91, 0.03) // label 2 - ).map(_.map(math.log)) + val pi = NaiveBayesSuite.smallPi + val theta = NaiveBayesSuite.smallTheta val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42) val testRDD = sc.parallelize(testData, 2) @@ -123,6 +133,24 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { NaiveBayes.train(sc.makeRDD(nan, 2)) } } + + test("model save/load") { + val model = NaiveBayesSuite.binaryModel + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index a2de7fbd41383..6de098b383ba3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.Utils object SVMSuite { @@ -56,6 +57,9 @@ object SVMSuite { y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } + /** Binary labels, 3 features */ + private val binaryModel = new SVMModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) + } class SVMSuite extends FunSuite with MLlibTestSparkContext { @@ -191,6 +195,38 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { // Turning off data validation should not throw an exception new SVMWithSGD().setValidateData(false).run(testRDDInvalid) } + + test("model save/load") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = SVMSuite.binaryModel + + model.clearThreshold() + assert(model.getThreshold.isEmpty) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = SVMModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + assert(sameModel.getThreshold.isEmpty) + } finally { + Utils.deleteRecursively(tempDir) + } + + // Save model with threshold. + try { + model.setThreshold(0.7) + model.save(sc, path) + val sameModel2 = SVMModel.load(sc, path) + assert(model.getThreshold.get == sameModel2.getThreshold.get) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 8b3e6e5ce9249..d50c43d439187 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -132,4 +132,31 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { assert(errors.forall(x => x <= 0.4)) } + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, 5.0, nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert(error.head > 0.8 & error.last < 0.2) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index c2cd56ea40adc..1b46a4012d731 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -31,7 +31,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { Vectors.dense(5.0, 10.0), Vectors.dense(4.0, 11.0) )) - + // expectations val Ew = 1.0 val Emu = Vectors.dense(5.0, 10.0) @@ -44,6 +44,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5) assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5) } + } test("two clusters") { @@ -54,7 +55,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) )) - + // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( Array(0.5, 0.5), @@ -63,7 +64,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0))) ) ) - + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) @@ -72,7 +73,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { .setK(2) .setInitialModel(initialGmm) .run(data) - + assert(gmm.weights(0) ~== Ew(0) absTol 1E-3) assert(gmm.weights(1) ~== Ew(1) absTol 1E-3) assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3) @@ -80,4 +81,61 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) } + + test("single cluster with sparse data") { + val data = sc.parallelize(Array( + Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)), + Vectors.sparse(3, Array(0, 2), Array(2.0, 4.0)), + Vectors.sparse(3, Array(1), Array(6.0)) + )) + + val Ew = 1.0 + val Emu = Vectors.dense(2.0, 2.0, 2.0) + val Esigma = Matrices.dense(3, 3, + Array(8.0 / 3.0, -4.0, 4.0 / 3.0, -4.0, 8.0, -4.0, 4.0 / 3.0, -4.0, 8.0 / 3.0) + ) + + val seeds = Array(42, 1994, 27, 11, 0) + seeds.foreach { seed => + val gmm = new GaussianMixture().setK(1).setSeed(seed).run(data) + assert(gmm.weights(0) ~== Ew absTol 1E-5) + assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5) + assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5) + } + } + + test("two clusters with sparse data") { + val data = sc.parallelize(Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + )) + + val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray)) + // we set an initial gaussian to induce expected results + val initialGmm = new GaussianMixtureModel( + Array(0.5, 0.5), + Array( + new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))), + new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0))) + ) + ) + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) + val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) + val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) + + val sparseGMM = new GaussianMixture() + .setK(2) + .setInitialModel(initialGmm) + .run(data) + + assert(sparseGMM.weights(0) ~== Ew(0) absTol 1E-3) + assert(sparseGMM.weights(1) ~== Ew(1) absTol 1E-3) + assert(sparseGMM.gaussians(0).mu ~== Emu(0) absTol 1E-3) + assert(sparseGMM.gaussians(1).mu ~== Emu(1) absTol 1E-3) + assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) + assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 03ecd9ca730be..6315c03a700f1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -51,8 +51,8 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext .setK(2) .run(sc.parallelize(similarities, 2)) val predictions = Array.fill(2)(mutable.Set.empty[Long]) - model.assignments.collect().foreach { case (i, c) => - predictions(c) += i + model.assignments.collect().foreach { a => + predictions(a.cluster) += a.id } assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) @@ -61,8 +61,8 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext .setInitializationMode("degree") .run(sc.parallelize(similarities, 2)) val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) - model2.assignments.collect().foreach { case (i, c) => - predictions2(c) += i + model2.assignments.collect().foreach { a => + predictions2(a.cluster) += a.id } assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 71ef60da6dd32..bd5b9cc3afa10 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -22,7 +22,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { - test("FP-Growth") { + + test("FP-Growth using String type") { val transactions = Seq( "r z h k p", "z y x w v u t s", @@ -45,8 +46,8 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { .setMinSupport(0.5) .setNumPartitions(2) .run(rdd) - val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) => - (items.toSet, count) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) } val expected = Set( (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L), @@ -70,4 +71,52 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { .run(rdd) assert(model1.freqItemsets.count() === 625) } + + test("FP-Growth using Int type") { + val transactions = Seq( + "1 2 3", + "1 2 3 4", + "5 4 3 2 1", + "6 5 4 3 2 1", + "2 4", + "1 3", + "1 7") + .map(_.split(" ").map(_.toInt).toArray) + val rdd = sc.parallelize(transactions, 2).cache() + + val fpg = new FPGrowth() + + val model6 = fpg + .setMinSupport(0.9) + .setNumPartitions(1) + .run(rdd) + assert(model6.freqItemsets.count() === 0) + + val model3 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + assert(model3.freqItemsets.first().items.getClass === Array(1).getClass, + "frequent itemsets should use primitive arrays") + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + val expected = Set( + (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L), + (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L), + (Set(2, 4), 4L), (Set(1, 2, 3), 4L)) + assert(freqItemsets3.toSet === expected) + + val model2 = fpg + .setMinSupport(0.3) + .setNumPartitions(4) + .run(rdd) + assert(model2.freqItemsets.count() === 15) + + val model1 = fpg + .setMinSupport(0.1) + .setNumPartitions(8) + .run(rdd) + assert(model1.freqItemsets.count() === 65) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index dac28a369b5b2..699f009f0f2ec 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -38,7 +38,7 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext var graphsToCheck = Seq.empty[GraphToCheck] val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, None, 10) + val checkpointer = new PeriodicGraphCheckpointer(graph1, 10) graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) checkPersistence(graphsToCheck, 1) @@ -57,9 +57,9 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext val path = tempDir.toURI.toString val checkpointInterval = 2 var graphsToCheck = Seq.empty[GraphToCheck] - + sc.setCheckpointDir(path) val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, Some(path), checkpointInterval) + val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval) graph1.edges.count() graph1.vertices.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index b0b78acd6df16..002cb253862b5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -166,6 +166,14 @@ class BLASSuite extends FunSuite { syr(alpha, y, dA) } } + + val xSparse = new SparseVector(4, Array(0, 2, 3), Array(1.0, 3.0, 4.0)) + val dD = new DenseMatrix(4, 4, + Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8)) + syr(0.1, xSparse, dD) + val expectedSparse = new DenseMatrix(4, 4, + Array(0.1, 1.2, 2.5, 3.5, 1.2, 3.2, 5.3, 4.6, 2.5, 5.3, 2.7, 4.2, 3.5, 4.6, 4.2, 2.4)) + assert(dD ~== expectedSparse absTol 1e-15) } test("gemm") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index b1ebfde0e5e57..c098b5458fe6b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -137,8 +137,8 @@ class MatricesSuite extends FunSuite { val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) val deMat1 = new DenseMatrix(m, n, allValues) - val spMat2 = deMat1.toSparse() - val deMat2 = spMat1.toDense() + val spMat2 = deMat1.toSparse + val deMat2 = spMat1.toDense assert(spMat1.toBreeze === spMat2.toBreeze) assert(deMat1.toBreeze === deMat2.toBreeze) @@ -185,8 +185,8 @@ class MatricesSuite extends FunSuite { assert(!dA.toArray.eq(dAT.toArray), "has to have a new array") assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), "should not copy array") - assert(dAT.toSparse().toBreeze === sATexpected.toBreeze) - assert(sAT.toDense().toBreeze === dATexpected.toBreeze) + assert(dAT.toSparse.toBreeze === sATexpected.toBreeze) + assert(sAT.toDense.toBreeze === dATexpected.toBreeze) } test("foreachActive") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index b9caecc904a23..9801e87576744 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { @@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext new MatrixFactorizationModel(rank, userFeatures, prodFeatures1) } } + + test("save/load") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = { + features.mapValues(_.toSeq).collect().toSet + } + try { + model.save(sc, path) + val newModel = MatrixFactorizationModel.load(sc, path) + assert(newModel.rank === rank) + assert(collect(newModel.userFeatures) === collect(userFeatures)) + assert(collect(newModel.productFeatures) === collect(prodFeatures)) + } finally { + Utils.deleteRecursively(tempDir) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 2668dcc14a842..c9f5dc069ef2e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -24,6 +24,13 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils + +private object LassoSuite { + + /** 3 features */ + val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} class LassoSuite extends FunSuite with MLlibTestSparkContext { @@ -115,6 +122,23 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("model save/load") { + val model = LassoSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LassoModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 864622a9296a6..3781931c2f819 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -24,6 +24,13 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils + +private object LinearRegressionSuite { + + /** 3 features */ + val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { @@ -124,6 +131,23 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { validatePrediction( sparseValidationData.map(row => model.predict(row.features)), sparseValidationData) } + + test("model save/load") { + val model = LinearRegressionSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LinearRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 18d3bf5ea4eca..43d61151e2471 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -25,6 +25,13 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils + +private object RidgeRegressionSuite { + + /** 3 features */ + val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { @@ -75,6 +82,23 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(ridgeErr < linearErr, "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } + + test("model save/load") { + val model = RidgeRegressionSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = RidgeRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 70b43ddb7daf5..24fd8df691817 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -139,4 +139,32 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) assert(errors.forall(x => x <= 0.1)) } + + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert((error.head - error.last) > 2) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 9347eaf9221a8..7b1aed5ffeb3e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -29,8 +29,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} +import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { @@ -857,9 +859,32 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(topNode.leftNode.get.impurity === 0.0) assert(topNode.rightNode.get.impurity === 0.0) } + + test("Node.subtreeIterator") { + val model = DecisionTreeSuite.createModel(Classification) + val nodeIds = model.topNode.subtreeIterator.map(_.id).toArray.sorted + assert(nodeIds === DecisionTreeSuite.createdModelNodeIds) + } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(Classification, Regression).foreach { algo => + val model = DecisionTreeSuite.createModel(algo) + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = DecisionTreeModel.load(sc, path) + DecisionTreeSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } } -object DecisionTreeSuite { +object DecisionTreeSuite extends FunSuite { def validateClassifier( model: DecisionTreeModel, @@ -979,4 +1004,95 @@ object DecisionTreeSuite { arr } + /** Create a leaf node with the given node ID */ + private def createLeafNode(id: Int): Node = { + Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = true) + } + + /** + * Create an internal node with the given node ID and feature type. + * Note: This does NOT set the child nodes. + */ + private def createInternalNode(id: Int, featureType: FeatureType): Node = { + val node = Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = false) + featureType match { + case Continuous => + node.split = Some(new Split(feature = 0, threshold = 0.5, Continuous, + categories = List.empty[Double])) + case Categorical => + node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical, + categories = List(0.0, 1.0))) + } + // TODO: The information gain stats should be consistent with the same info stored in children. + node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2, + leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6))) + node + } + + /** + * Create a tree model. This is deterministic and contains a variety of node and feature types. + */ + private[tree] def createModel(algo: Algo): DecisionTreeModel = { + val topNode = createInternalNode(id = 1, Continuous) + val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical)) + val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7)) + topNode.leftNode = Some(node2) + topNode.rightNode = Some(node3) + node3.leftNode = Some(node6) + node3.rightNode = Some(node7) + new DecisionTreeModel(topNode, algo) + } + + /** Sorted Node IDs matching the model returned by [[createModel()]] */ + private val createdModelNodeIds = Array(1, 2, 3, 6, 7) + + /** + * Check if the two trees are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + * If the trees are not equal, this prints the two trees and throws an exception. + */ + private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + try { + assert(a.algo === b.algo) + checkEqual(a.topNode, b.topNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + /** + * Return true iff the two nodes and their descendents are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + */ + private def checkEqual(a: Node, b: Node): Unit = { + assert(a.id === b.id) + assert(a.predict === b.predict) + assert(a.impurity === b.impurity) + assert(a.isLeaf === b.isLeaf) + assert(a.split === b.split) + (a.stats, b.stats) match { + // TODO: Check other fields besides the infomation gain. + case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain) + case (None, None) => + case _ => throw new AssertionError( + s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})") + } + (a.leftNode, b.leftNode) match { + case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode) + case (None, None) => + case _ => throw new AssertionError("Only one instance has leftNode defined. " + + s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})") + } + (a.rightNode, b.rightNode) match { + case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode) + case (None, None) => + case _ => throw new AssertionError("Only one instance has rightNode defined. " + + s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index e8341a5d0d104..bde47606eb001 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -24,8 +24,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss} - +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + /** * Test suite for [[GradientBoostedTrees]]. @@ -35,32 +37,30 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { case (numIterations, learningRate, subsamplingRate) => - GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed => - val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) - - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) - val boostingStrategy = - new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) - - val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) - - assert(gbt.trees.size === numIterations) - try { - EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) - } catch { - case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + - s" subsamplingRate=$subsamplingRate") - throw e - } - - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val dt = DecisionTree.train(remappedInput, treeStrategy) - - // Make sure trees are the same. - assert(gbt.trees.head.toString == dt.toString) + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. + assert(gbt.trees.head.toString == dt.toString) } } @@ -133,14 +133,37 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { BoostingStrategy.defaultParams(algo) } } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + + Array(Classification, Regression).foreach { algo => + val model = new GradientBoostedTreesModel(algo, trees, treeWeights) + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = GradientBoostedTreesModel.load(sc, path) + assert(model.algo == sameModel.algo) + model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) => + DecisionTreeSuite.checkEqual(treeA, treeB) + } + assert(model.treeWeights === sameModel.treeWeights) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } } -object GradientBoostedTreesSuite { +private object GradientBoostedTreesSuite { // Combinations for estimators, learning rates and subsamplingRate val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) - val randomSeeds = Array(681283, 4398) - val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 55e963977b54f..ee3bc98486862 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -27,8 +27,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Gini, Variance} -import org.apache.spark.mllib.tree.model.Node +import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + /** * Test suite for [[RandomForest]]. @@ -212,6 +214,26 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { assert(rf1.toDebugString != rf2.toDebugString) } -} - + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(Classification, Regression).foreach { algo => + val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(algo)).toArray + val model = new RandomForestModel(algo, trees) + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = RandomForestModel.load(sc, path) + assert(model.algo == sameModel.algo) + model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) => + DecisionTreeSuite.checkEqual(treeA, treeB) + } + } finally { + Utils.deleteRecursively(tempDir) + } + } + } +} diff --git a/network/common/pom.xml b/network/common/pom.xml index 8f7c924d6b3a3..dff5ebd476188 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml @@ -80,6 +80,11 @@ mockito-all test + + org.slf4j + slf4j-log4j12 + test + diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 91d1e8a538a77..0f999f5dfe8d8 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -72,9 +72,11 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { in.encode(header); assert header.writableBytes() == 0; - out.add(header); if (body != null && bodyLength > 0) { - out.add(body); + out.add(new MessageWithHeader(header, body, bodyLength)); + } else { + out.add(header); } } + } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java new file mode 100644 index 0000000000000..d686a951467cf --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; + +/** + * A wrapper message that holds two separate pieces (a header and a body). + * + * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. + */ +class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { + + private final ByteBuf header; + private final int headerLength; + private final Object body; + private final long bodyLength; + private long totalBytesTransferred; + + MessageWithHeader(ByteBuf header, Object body, long bodyLength) { + Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, + "Body must be a ByteBuf or a FileRegion."); + this.header = header; + this.headerLength = header.readableBytes(); + this.body = body; + this.bodyLength = bodyLength; + } + + @Override + public long count() { + return headerLength + bodyLength; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return totalBytesTransferred; + } + + /** + * This code is more complicated than you would think because we might require multiple + * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. + * + * The contract is that the caller will ensure position is properly set to the total number + * of bytes transferred so far (i.e. value returned by transfered()). + */ + @Override + public long transferTo(final WritableByteChannel target, final long position) throws IOException { + Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); + // Bytes written for header in this call. + long writtenHeader = 0; + if (header.readableBytes() > 0) { + writtenHeader = copyByteBuf(header, target); + totalBytesTransferred += writtenHeader; + if (header.readableBytes() > 0) { + return writtenHeader; + } + } + + // Bytes written for body in this call. + long writtenBody = 0; + if (body instanceof FileRegion) { + writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); + } else if (body instanceof ByteBuf) { + writtenBody = copyByteBuf((ByteBuf) body, target); + } + totalBytesTransferred += writtenBody; + + return writtenHeader + writtenBody; + } + + @Override + protected void deallocate() { + header.release(); + ReferenceCountUtil.release(body); + } + + private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { + int written = target.write(buf.nioBuffer()); + buf.skipBytes(written); + return written; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 625c3257d764e..ef209991804b4 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -100,8 +100,7 @@ protected void initChannel(SocketChannel ch) throws Exception { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); - channelFuture.syncUninterruptibly(); + bindRightPort(portToBind); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); logger.debug("Shuffle server started on port :" + port); @@ -123,4 +122,37 @@ public void close() { bootstrap = null; } + /** + * Attempt to bind to the specified port up to a fixed number of retries. + * If all attempts fail after the max number of retries, exit. + */ + private void bindRightPort(int portToBind) { + int maxPortRetries = conf.portMaxRetries(); + + for (int i = 0; i <= maxPortRetries; i++) { + int tryPort = -1; + if (0 == portToBind) { + // Do not increment port if tryPort is 0, which is treated as a special port + tryPort = 0; + } else { + // If the new port wraps around, do not try a privilege port + tryPort = ((portToBind + i - 1024) % (65536 - 1024)) + 1024; + } + try { + channelFuture = bootstrap.bind(new InetSocketAddress(tryPort)); + channelFuture.syncUninterruptibly(); + return; + } catch (Exception e) { + logger.warn("Netty service could not bind on port " + tryPort + + ". Attempting the next port."); + if (i >= maxPortRetries) { + logger.error(e.getMessage() + ": Netty server failed after " + + maxPortRetries + " retries."); + + // If it can't find a right port, it should exit directly. + System.exit(-1); + } + } + } + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 6c9178688693f..2eaf3b71d9a49 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -98,4 +98,11 @@ public int memoryMapBytes() { public boolean lazyFileDescriptor() { return conf.getBoolean("spark.shuffle.io.lazyFD", true); } + + /** + * Maximum number of retries when binding to a port before giving up. + */ + public int portMaxRetries() { + return conf.getInt("spark.port.maxRetries", 16); + } } diff --git a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java new file mode 100644 index 0000000000000..b525ed69fc9fb --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +public class ByteArrayWritableChannel implements WritableByteChannel { + + private final byte[] data; + private int offset; + + public ByteArrayWritableChannel(int size) { + this.data = new byte[size]; + this.offset = 0; + } + + public byte[] getData() { + return data; + } + + @Override + public int write(ByteBuffer src) { + int available = src.remaining(); + src.get(data, offset, available); + offset += available; + return available; + } + + @Override + public void close() { + + } + + @Override + public boolean isOpen() { + return true; + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 43dc0cf8c7194..860dd6d9b3915 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -17,26 +17,34 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.primitives.Ints; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.FileRegion; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.MessageToMessageEncoder; import org.junit.Test; import static org.junit.Assert.assertEquals; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { private void testServerToClient(Message msg) { - EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( @@ -51,7 +59,8 @@ private void testServerToClient(Message msg) { } private void testClientToServer(Message msg) { - EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( @@ -83,4 +92,25 @@ public void responses() { testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); } + + /** + * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer + * bytes, but messages, so this is needed so that the frame decoder on the receiving side can + * understand what MessageWithHeader actually contains. + */ + private static class FileRegionEncoder extends MessageToMessageEncoder { + + @Override + public void encode(ChannelHandlerContext ctx, FileRegion in, List out) + throws Exception { + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count())); + while (in.transfered() < in.count()) { + in.transferTo(channel, in.transfered()); + } + out.add(Unpooled.wrappedBuffer(channel.getData())); + } + + } + } diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java new file mode 100644 index 0000000000000..ff985096d72d5 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.ByteArrayWritableChannel; + +public class MessageWithHeaderSuite { + + @Test + public void testSingleWrite() throws Exception { + testFileRegionBody(8, 8); + } + + @Test + public void testShortWrite() throws Exception { + testFileRegionBody(8, 1); + } + + @Test + public void testByteBufBody() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ByteBuf body = Unpooled.copyLong(84); + MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes()); + + ByteBuf result = doWrite(msg, 1); + assertEquals(msg.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + assertEquals(84, result.readLong()); + } + + private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { + ByteBuf header = Unpooled.copyLong(42); + int headerLength = header.readableBytes(); + TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); + MessageWithHeader msg = new MessageWithHeader(header, region, region.count()); + + ByteBuf result = doWrite(msg, totalWrites / writesPerCall); + assertEquals(headerLength + region.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + for (long i = 0; i < 8; i++) { + assertEquals(i, result.readLong()); + } + } + + private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { + int writes = 0; + ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); + while (msg.transfered() < msg.count()) { + msg.transferTo(channel, msg.transfered()); + writes++; + } + assertTrue("Not enough writes!", minExpectedWrites <= writes); + return Unpooled.wrappedBuffer(channel.getData()); + } + + private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final int writeCount; + private final int writesPerCall; + private int written; + + TestFileRegion(int totalWrites, int writesPerCall) { + this.writeCount = totalWrites; + this.writesPerCall = writesPerCall; + } + + @Override + public long count() { + return 8 * writeCount; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return 8 * written; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + for (int i = 0; i < writesPerCall; i++) { + ByteBuf buf = Unpooled.copyLong((position / 8) + i); + ByteBuffer nio = buf.nioBuffer(); + while (nio.remaining() > 0) { + target.write(nio); + } + buf.release(); + written++; + } + return 8 * writesPerCall; + } + + @Override + protected void deallocate() { + } + + } + +} diff --git a/network/common/src/test/resources/log4j.properties b/network/common/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e8da774f7ca9e --- /dev/null +++ b/network/common/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Silence verbose logs from 3rd-party libraries. +log4j.logger.io.netty=INFO diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index c2d0300ecd904..ded317d728d80 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index acec8f18f2b5c..a8b7c8bfd30b0 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml @@ -33,6 +33,8 @@ http://spark.apache.org/ network-yarn + + provided @@ -47,7 +49,6 @@ org.apache.hadoop hadoop-client - provided diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index a34aabe9e78a6..63b21222e7b77 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -76,6 +76,9 @@ public class YarnShuffleService extends AuxiliaryService { // The actual server that serves shuffle files private TransportServer shuffleServer = null; + // Handles registering executors and opening shuffle blocks + private ExternalShuffleBlockHandler blockHandler; + public YarnShuffleService() { super("spark_shuffle"); logger.info("Initializing YARN shuffle service for Spark"); @@ -99,7 +102,8 @@ protected void serviceInit(Configuration conf) { // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); - RpcHandler rpcHandler = new ExternalShuffleBlockHandler(transportConf); + blockHandler = new ExternalShuffleBlockHandler(transportConf); + RpcHandler rpcHandler = blockHandler; if (authEnabled) { secretManager = new ShuffleSecretManager(); rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); @@ -136,6 +140,7 @@ public void stopApplication(ApplicationTerminationContext context) { if (isAuthenticationEnabled()) { secretManager.unregisterApp(appId); } + blockHandler.applicationRemoved(appId, false /* clean up local dirs */); } catch (Exception e) { logger.error("Exception when stopping application {}", appId, e); } diff --git a/pom.xml b/pom.xml index e25eced877578..a599ad470decd 100644 --- a/pom.xml +++ b/pom.xml @@ -25,8 +25,8 @@ 14 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -135,9 +135,12 @@ 1.6.0rc3 1.2.3 8.1.14.v20131031 + 3.0.0.v201112011016 0.5.0 + 2.4.0 + 2.0.8 3.1.0 - 1.7.6 + 1.7.7 0.7.1 1.8.3 @@ -149,7 +152,9 @@ 2.10 ${scala.version} org.scala-lang + 3.6.3 1.8.8 + 2.4.4 1.1.1.6 @@ -390,7 +395,7 @@ provided - + org.apache.commons commons-lang3 @@ -399,7 +404,7 @@ commons-codec commons-codec - 1.5 + 1.10 org.apache.commons @@ -417,6 +422,13 @@ 2.42.2 test + + + xml-apis + xml-apis + 1.4.01 + test + org.slf4j slf4j-api @@ -573,6 +585,24 @@ metrics-graphite ${codahale.metrics.version} + + com.fasterxml.jackson.core + jackson-databind + ${fasterxml.jackson.version} + + + + com.fasterxml.jackson.module + jackson-module-scala_2.10 + ${fasterxml.jackson.version} + + + com.google.guava + guava + + + org.scala-lang scala-compiler @@ -604,19 +634,6 @@ 2.2.1 test - - org.easymock - easymockclassextension - 3.1 - test - - - - asm - asm - 3.3.1 - test - org.mockito mockito-all @@ -924,6 +941,16 @@ ${codehaus.jackson.version} ${hadoop.deps.scope} + + org.codehaus.jackson + jackson-xc + ${codehaus.jackson.version} + + + org.codehaus.jackson + jackson-jaxrs + ${codehaus.jackson.version} + ${hive.group} hive-beeline @@ -950,6 +977,10 @@ com.esotericsoftware.kryo kryo + + org.apache.avro + avro-mapred + @@ -1067,6 +1098,12 @@ scala-maven-plugin 3.2.0 + + eclipse-add-source + + add-source + + scala-compile-first process-resources @@ -1149,13 +1186,19 @@ ${project.build.directory}/surefire-reports -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + + + ${test_classpath} + true ${session.executionRootDirectory} 1 false false - ${test_classpath} true false @@ -1534,6 +1577,7 @@ 2.5.0 0.98.7-hadoop2 hadoop2 + 1.9.13 @@ -1542,10 +1586,11 @@ 2.3.0 2.5.0 - 0.9.0 + 0.9.3 0.98.7-hadoop2 3.1.1 hadoop2 + 1.9.13 @@ -1554,10 +1599,11 @@ 2.4.0 2.5.0 - 0.9.0 + 0.9.3 0.98.7-hadoop2 3.1.1 hadoop2 + 1.9.13 @@ -1574,7 +1620,7 @@ 1.0.3-mapr-3.0.3 2.4.1-mapr-1408 - 0.94.17-mapr-1405 + 0.98.4-mapr-1408 3.4.5-mapr-1406 @@ -1584,7 +1630,7 @@ 2.4.1-mapr-1408 2.4.1-mapr-1408 - 0.94.17-mapr-1405-4.0.0-FCS + 0.98.4-mapr-1408 3.4.5-mapr-1406 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b17532c1d814c..ee6229aa6bbe1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,7 @@ object MimaExcludes { case v if v.startsWith("1.3") => Seq( MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), // These are needed if checking against the sbt build, since they are part of // the maven-generated artifacts in the 1.2 build. MimaBuild.excludeSparkPackage("unused"), @@ -142,6 +143,16 @@ object MimaExcludes { "org.apache.spark.graphx.Graph.getCheckpointFiles"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.graphx.Graph.isCheckpointed") + ) ++ Seq( + // SPARK-4789 Standardize ML Prediction APIs + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType") + ) ++ Seq( + // SPARK-4682 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock") ) case v if v.startsWith("1.2") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index fbc8983b953b7..20826913190a2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -177,6 +177,29 @@ object SparkBuild extends PomBuild { enable(Flume.settings)(streamingFlumeSink) + + /** + * Adds the ability to run the spark shell directly from SBT without building an assembly + * jar. + * + * Usage: `build/sbt sparkShell` + */ + val sparkShell = taskKey[Unit]("start a spark-shell.") + + enable(Seq( + connectInput in run := true, + fork := true, + outputStrategy in run := Some (StdoutOutput), + + javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=1g"), + + sparkShell := { + (runMain in Compile).toTask(" org.apache.spark.repl.Main -usejavacp").value + } + ))(assembly) + + enable(Seq(sparkShell := sparkShell in "assembly"))(spark) + // TODO: move this to its upstream project. override def projectDefinitions(baseDirectory: File): Seq[Project] = { super.projectDefinitions(baseDirectory).map { x => @@ -245,6 +268,7 @@ object SQL { |import org.apache.spark.sql.catalyst.plans.logical._ |import org.apache.spark.sql.catalyst.rules._ |import org.apache.spark.sql.catalyst.util._ + |import org.apache.spark.sql.Dsl._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.test.TestSQLContext._ |import org.apache.spark.sql.types._ @@ -275,6 +299,7 @@ object Hive { |import org.apache.spark.sql.catalyst.plans.logical._ |import org.apache.spark.sql.catalyst.rules._ |import org.apache.spark.sql.catalyst.util._ + |import org.apache.spark.sql.Dsl._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ @@ -332,25 +357,38 @@ object Unidoc { names.map(s => "org.apache.spark." + s).mkString(":") } + private def ignoreUndocumentedPackages(packages: Seq[Seq[File]]): Seq[Seq[File]] = { + packages + .map(_.filterNot(_.getName.contains("$"))) + .map(_.filterNot(_.getCanonicalPath.contains("akka"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) + .map(_.filterNot(_.getCanonicalPath.contains("python"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test"))) + } + lazy val settings = scalaJavaUnidocSettings ++ Seq ( publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, catalyst, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn), + + // Skip actual catalyst, but include the subproject. + // Catalyst is not public API and contains quasiquotes which break scaladoc. + unidocAllSources in (ScalaUnidoc, unidoc) := { + ignoreUndocumentedPackages((unidocAllSources in (ScalaUnidoc, unidoc)).value) + }, // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { - (unidocAllSources in (JavaUnidoc, unidoc)).value - .map(_.filterNot(_.getName.contains("$"))) - .map(_.filterNot(_.getCanonicalPath.contains("akka"))) - .map(_.filterNot(_.getCanonicalPath.contains("deploy"))) - .map(_.filterNot(_.getCanonicalPath.contains("network"))) - .map(_.filterNot(_.getCanonicalPath.contains("shuffle"))) - .map(_.filterNot(_.getCanonicalPath.contains("executor"))) - .map(_.filterNot(_.getCanonicalPath.contains("python"))) - .map(_.filterNot(_.getCanonicalPath.contains("collection"))) + ignoreUndocumentedPackages((unidocAllSources in (JavaUnidoc, unidoc)).value) }, // Javadoc options: create a window title, and group key packages on index page @@ -373,7 +411,10 @@ object Unidoc { ), "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" - ) + ), + + // Group similar methods together based on the @group annotation. + scalacOptions in (ScalaUnidoc, unidoc) ++= Seq("-groups") ) } @@ -383,6 +424,10 @@ object TestSettings { lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, + // Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes + // launched by the tests have access to the correct test-time classpath. + envVars in Test += ("SPARK_DIST_CLASSPATH" -> + (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":")), javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", @@ -395,10 +440,6 @@ object TestSettings { javaOptions in Test += "-ea", javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, - // This places test scope jars on the classpath of executors during tests. - javaOptions in Test += - "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files. - map(_.getAbsolutePath).mkString(":").stripSuffix(":"), javaOptions += "-Xmx3g", // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), diff --git a/python/docs/conf.py b/python/docs/conf.py index b00dce95d65b4..163987dd8e5fa 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -48,16 +48,16 @@ # General information about the project. project = u'PySpark' -copyright = u'2014, Author' +copyright = u'' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '1.3-SNAPSHOT' +version = 'master' # The full version, including alpha/beta/rc tags. -release = '1.3-SNAPSHOT' +release = os.environ.get('RELEASE_VERSION', version) # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -97,6 +97,10 @@ # If true, keep warnings as "system message" paragraphs in the built documents. #keep_warnings = False +# -- Options for autodoc -------------------------------------------------- + +# Look at the first line of the docstring for function and method signatures. +autodoc_docstring_signature = True # -- Options for HTML output ---------------------------------------------- diff --git a/python/docs/index.rst b/python/docs/index.rst index d150de9d5c502..f7eede9c3c82a 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -29,6 +29,14 @@ Core classes: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + :class:`pyspark.sql.SQLContext` + + Main entry point for DataFrame and SQL functionality. + + :class:`pyspark.sql.DataFrame` + + A distributed collection of data grouped into named columns. + Indices and tables ================== diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index f10d1339a9a8f..4da6d4a74a299 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -1,11 +1,8 @@ pyspark.ml package ===================== -Submodules ----------- - -pyspark.ml module ------------------ +Module Context +-------------- .. automodule:: pyspark.ml :members: diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index 4548b8739ed91..b706c5e376ef4 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -1,16 +1,13 @@ pyspark.mllib package ===================== -Submodules ----------- - pyspark.mllib.classification module ----------------------------------- .. automodule:: pyspark.mllib.classification :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.clustering module ------------------------------- @@ -18,7 +15,6 @@ pyspark.mllib.clustering module .. automodule:: pyspark.mllib.clustering :members: :undoc-members: - :show-inheritance: pyspark.mllib.feature module ------------------------------- @@ -42,7 +38,6 @@ pyspark.mllib.random module .. automodule:: pyspark.mllib.random :members: :undoc-members: - :show-inheritance: pyspark.mllib.recommendation module ----------------------------------- @@ -50,7 +45,6 @@ pyspark.mllib.recommendation module .. automodule:: pyspark.mllib.recommendation :members: :undoc-members: - :show-inheritance: pyspark.mllib.regression module ------------------------------- @@ -58,7 +52,7 @@ pyspark.mllib.regression module .. automodule:: pyspark.mllib.regression :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.stat module ------------------------- @@ -66,7 +60,6 @@ pyspark.mllib.stat module .. automodule:: pyspark.mllib.stat :members: :undoc-members: - :show-inheritance: pyspark.mllib.tree module ------------------------- @@ -74,7 +67,7 @@ pyspark.mllib.tree module .. automodule:: pyspark.mllib.tree :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.util module ------------------------- @@ -82,4 +75,3 @@ pyspark.mllib.util module .. automodule:: pyspark.mllib.util :members: :undoc-members: - :show-inheritance: diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index 65b3650ae10ab..6259379ed05b7 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -1,10 +1,23 @@ pyspark.sql module ================== -Module contents ---------------- +Module Context +-------------- .. automodule:: pyspark.sql :members: :undoc-members: - :show-inheritance: + + +pyspark.sql.types module +------------------------ +.. automodule:: pyspark.sql.types + :members: + :undoc-members: + + +pyspark.sql.functions module +---------------------------- +.. automodule:: pyspark.sql.functions + :members: + :undoc-members: diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst index f08185627d0bc..50822c93faba1 100644 --- a/python/docs/pyspark.streaming.rst +++ b/python/docs/pyspark.streaming.rst @@ -8,3 +8,10 @@ Module contents :members: :undoc-members: :show-inheritance: + +pyspark.streaming.kafka module +------------------------------ +.. automodule:: pyspark.streaming.kafka + :members: + :undoc-members: + :show-inheritance: diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index d3efcdf221d82..5f70ac6ed8fe6 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -22,17 +22,17 @@ - :class:`SparkContext`: Main entry point for Spark functionality. - - L{RDD} + - :class:`RDD`: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - - L{Broadcast} + - :class:`Broadcast`: A broadcast variable that gets reused across tasks. - - L{Accumulator} + - :class:`Accumulator`: An "add-only" shared variable that tasks can only add values to. - - L{SparkConf} + - :class:`SparkConf`: For configuring Spark. - - L{SparkFiles} + - :class:`SparkFiles`: Access files shipped with jobs. - - L{StorageLevel} + - :class:`StorageLevel`: Finer-grained cache persistence levels. """ @@ -45,6 +45,7 @@ from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast from pyspark.serializers import MarshalSerializer, PickleSerializer +from pyspark.status import * from pyspark.profiler import Profiler, BasicProfiler # for back compatibility @@ -53,5 +54,5 @@ __all__ = [ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", - "Profiler", "BasicProfiler", + "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", ] diff --git a/python/pyspark/context.py b/python/pyspark/context.py index bf1f61c8504ed..78dccc40470e3 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -21,6 +21,8 @@ from threading import Lock from tempfile import NamedTemporaryFile +from py4j.java_collections import ListConverter + from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast @@ -30,12 +32,11 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel -from pyspark.rdd import RDD +from pyspark.rdd import RDD, _load_from_socket from pyspark.traceback_utils import CallSite, first_spark_call +from pyspark.status import StatusTracker from pyspark.profiler import ProfilerCollector, BasicProfiler -from py4j.java_collections import ListConverter - __all__ = ['SparkContext'] @@ -58,12 +59,13 @@ class SparkContext(object): _gateway = None _jvm = None - _writeToFile = None _next_accum_id = 0 _active_spark_context = None _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH + PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar') + def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, gateway=None, jsc=None, profiler_cls=BasicProfiler): @@ -185,7 +187,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - if filename.lower().endswith("zip") or filename.lower().endswith("egg"): + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: self._python_includes.append(filename) sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) @@ -218,7 +220,6 @@ def _ensure_initialized(cls, instance=None, gateway=None): if not SparkContext._gateway: SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile if instance: if (SparkContext._active_spark_context and @@ -705,7 +706,7 @@ def addPyFile(self, path): self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix - if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'): + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: self._python_includes.append(filename) # for tests in local mode sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) @@ -808,6 +809,12 @@ def cancelAllJobs(self): """ self._jsc.sc().cancelAllJobs() + def statusTracker(self): + """ + Return :class:`StatusTracker` object + """ + return StatusTracker(self._jsc.statusTracker()) + def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): """ Executes the given partitionFunc on the specified set of partitions, @@ -831,8 +838,9 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) - return list(mappedRDD._collect_iterator_through_file(it)) + port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, + allowLocal) + return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) def show_profiles(self): """ Print the profile stats to stdout """ diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index f09587f211708..93885985fe377 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -61,7 +61,10 @@ def worker(sock): except SystemExit as exc: exit_code = compute_real_exit_code(exc.code) finally: - outfile.flush() + try: + outfile.flush() + except Exception: + pass return exit_code diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index a0a028446d5fd..19ee2e3add7c2 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -17,22 +17,22 @@ import atexit import os -import sys +import select import signal import shlex +import socket import platform from subprocess import Popen, PIPE -from threading import Thread from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from pyspark.serializers import read_int -def launch_gateway(): - SPARK_HOME = os.environ["SPARK_HOME"] - gateway_port = -1 +def launch_gateway(): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: + SPARK_HOME = os.environ["SPARK_HOME"] # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" @@ -41,36 +41,42 @@ def launch_gateway(): submit_args = submit_args if submit_args is not None else "" submit_args = shlex.split(submit_args) command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"] + + # Start a socket that will be used by PythonGatewayServer to communicate its port to us + callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + callback_socket.bind(('127.0.0.1', 0)) + callback_socket.listen(1) + callback_host, callback_port = callback_socket.getsockname() + env = dict(os.environ) + env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host + env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port) + + # Launch the Java gateway. + # We open a pipe to stdin so that the Java gateway can die when the pipe is broken if not on_windows: # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - env = dict(os.environ) env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env) + proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) else: # preexec_fn not supported on Windows - proc = Popen(command, stdout=PIPE, stdin=PIPE) + proc = Popen(command, stdin=PIPE, env=env) - try: - # Determine which ephemeral port the server started on: - gateway_port = proc.stdout.readline() - gateway_port = int(gateway_port) - except ValueError: - # Grab the remaining lines of stdout - (stdout, _) = proc.communicate() - exit_code = proc.poll() - error_msg = "Launching GatewayServer failed" - error_msg += " with exit code %d!\n" % exit_code if exit_code else "!\n" - error_msg += "Warning: Expected GatewayServer to output a port, but found " - if gateway_port == "" and stdout == "": - error_msg += "no output.\n" - else: - error_msg += "the following:\n\n" - error_msg += "--------------------------------------------------------------\n" - error_msg += gateway_port + stdout - error_msg += "--------------------------------------------------------------\n" - raise Exception(error_msg) + gateway_port = None + # We use select() here in order to avoid blocking indefinitely if the subprocess dies + # before connecting + while gateway_port is None and proc.poll() is None: + timeout = 1 # (seconds) + readable, _, _ = select.select([callback_socket], [], [], timeout) + if callback_socket in readable: + gateway_connection = callback_socket.accept()[0] + # Determine which ephemeral port the server started on: + gateway_port = read_int(gateway_connection.makefile()) + gateway_connection.close() + callback_socket.close() + if gateway_port is None: + raise Exception("Java gateway process exited before sending the driver its port number") # In Windows, ensure the Java child processes do not linger after Python has exited. # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when @@ -88,21 +94,6 @@ def killChild(): Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)]) atexit.register(killChild) - # Create a thread to echo output from the GatewayServer, which is required - # for Java log output to show up: - class EchoOutputThread(Thread): - - def __init__(self, stream): - Thread.__init__(self) - self.daemon = True - self.stream = stream - - def run(self): - while True: - line = self.stream.readline() - sys.stderr.write(line) - EchoOutputThread(proc.stdout).start() - # Connect to the gateway gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False) diff --git a/python/pyspark/join.py b/python/pyspark/join.py index b4a844713745a..efc1ef9396412 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -35,8 +35,8 @@ def _do_python_join(rdd, other, numPartitions, dispatch): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) + vs = rdd.mapValues(lambda v: (1, v)) + ws = other.mapValues(lambda v: (2, v)) return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__())) @@ -98,8 +98,8 @@ def dispatch(seq): def python_cogroup(rdds, numPartitions): def make_mapper(i): - return lambda (k, v): (k, (i, v)) - vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)] + return lambda v: (i, v) + vrdds = [rdd.mapValues(make_mapper(i)) for i, rdd in enumerate(rdds)] union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds) rdd_len = len(vrdds) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 6bd2aa8e47837..7f42de531f3b4 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -15,10 +15,11 @@ # limitations under the License. # -from pyspark.ml.util import inherit_doc +from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\ HasRegParam +from pyspark.mllib.common import inherit_doc __all__ = ['LogisticRegression', 'LogisticRegressionModel'] @@ -32,22 +33,46 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors - >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \ - Row(label=1.0, features=Vectors.dense(1.0)), \ - Row(label=0.0, features=Vectors.sparse(1, [], []))])) - >>> lr = LogisticRegression() \ - .setMaxIter(5) \ - .setRegParam(0.01) - >>> model = lr.fit(dataset) - >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))])) + >>> df = sc.parallelize([ + ... Row(label=1.0, features=Vectors.dense(1.0)), + ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + >>> model = lr.fit(df) + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> print model.transform(test0).head().prediction 0.0 - >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))])) + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> print model.transform(test1).head().prediction 1.0 + >>> lr.setParams("vector") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. """ _java_class = "org.apache.spark.ml.classification.LogisticRegression" + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.1): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.1) + """ + super(LogisticRegression, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.1): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.1) + Sets params for logistic regression. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + def _create_model(self, java_model): return LogisticRegressionModel(java_model) @@ -66,9 +91,9 @@ class LogisticRegressionModel(JavaModel): # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx + globs['sqlContext'] = sqlContext (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index e088acd0ca82d..1cfcd019dfb18 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -16,8 +16,9 @@ # from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures -from pyspark.ml.util import inherit_doc +from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaTransformer +from pyspark.mllib.common import inherit_doc __all__ = ['Tokenizer', 'HashingTF'] @@ -29,18 +30,45 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): splits it by white spaces. >>> from pyspark.sql import Row - >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")])) - >>> tokenizer = Tokenizer() \ - .setInputCol("text") \ - .setOutputCol("words") - >>> print tokenizer.transform(dataset).head() + >>> df = sc.parallelize([Row(text="a b c")]).toDF() + >>> tokenizer = Tokenizer(inputCol="text", outputCol="words") + >>> print tokenizer.transform(df).head() Row(text=u'a b c', words=[u'a', u'b', u'c']) - >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head() + >>> # Change a parameter. + >>> print tokenizer.setParams(outputCol="tokens").transform(df).head() Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Temporarily modify a parameter. + >>> print tokenizer.transform(df, {tokenizer.outputCol: "words"}).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> print tokenizer.transform(df).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Must use keyword arguments to specify params. + >>> tokenizer.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. """ _java_class = "org.apache.spark.ml.feature.Tokenizer" + @keyword_only + def __init__(self, inputCol="input", outputCol="output"): + """ + __init__(self, inputCol="input", outputCol="output") + """ + super(Tokenizer, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol="input", outputCol="output"): + """ + setParams(self, inputCol="input", outputCol="output") + Sets params for this Tokenizer. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): @@ -49,20 +77,37 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): hashing trick. >>> from pyspark.sql import Row - >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])])) - >>> hashingTF = HashingTF() \ - .setNumFeatures(10) \ - .setInputCol("words") \ - .setOutputCol("features") - >>> print hashingTF.transform(dataset).head().features + >>> df = sc.parallelize([Row(words=["a", "b", "c"])]).toDF() + >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + >>> print hashingTF.transform(df).head().features + (10,[7,8,9],[1.0,1.0,1.0]) + >>> print hashingTF.setParams(outputCol="freqs").transform(df).head().freqs (10,[7,8,9],[1.0,1.0,1.0]) >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} - >>> print hashingTF.transform(dataset, params).head().vector + >>> print hashingTF.transform(df, params).head().vector (5,[2,3,4],[1.0,1.0,1.0]) """ _java_class = "org.apache.spark.ml.feature.HashingTF" + @keyword_only + def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + """ + __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + """ + super(HashingTF, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + """ + setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + Sets params for this HashingTF. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + if __name__ == "__main__": import doctest @@ -72,9 +117,9 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx + globs['sqlContext'] = sqlContext (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 5566792cead48..e3a53dd780c4c 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -80,3 +80,11 @@ def _dummy(): dummy = Params() dummy.uid = "undefined" return dummy + + def _set_params(self, **kwargs): + """ + Sets params. + """ + for param, value in kwargs.iteritems(): + self.paramMap[getattr(self, param)] = value + return self diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 2d239f8c802a0..83880a5afcd1d 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -18,7 +18,8 @@ from abc import ABCMeta, abstractmethod from pyspark.ml.param import Param, Params -from pyspark.ml.util import inherit_doc +from pyspark.ml.util import keyword_only +from pyspark.mllib.common import inherit_doc __all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel'] @@ -38,7 +39,7 @@ def fit(self, dataset, params={}): Fits a model to the input dataset with optional parameters. :param dataset: input dataset, which is an instance of - :py:class:`pyspark.sql.SchemaRDD` + :py:class:`pyspark.sql.DataFrame` :param params: an optional param map that overwrites embedded params :returns: fitted model @@ -61,7 +62,7 @@ def transform(self, dataset, params={}): Transforms the input dataset with optional parameters. :param dataset: input dataset, which is an instance of - :py:class:`pyspark.sql.SchemaRDD` + :py:class:`pyspark.sql.DataFrame` :param params: an optional param map that overwrites embedded params :returns: transformed dataset @@ -89,10 +90,16 @@ class Pipeline(Estimator): identity transformer. """ - def __init__(self): + @keyword_only + def __init__(self, stages=[]): + """ + __init__(self, stages=[]) + """ super(Pipeline, self).__init__() #: Param for pipeline stages. self.stages = Param(self, "stages", "pipeline stages") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) def setStages(self, value): """ @@ -110,6 +117,15 @@ def getStages(self): if self.stages in self.paramMap: return self.paramMap[self.stages] + @keyword_only + def setParams(self, stages=[]): + """ + setParams(self, stages=[]) + Sets params for Pipeline. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + def fit(self, dataset, params={}): paramMap = self._merge_params(params) stages = paramMap[self.stages] diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index b1caa84b6306a..6f7f39c40eb5a 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -15,21 +15,22 @@ # limitations under the License. # +from functools import wraps import uuid -def inherit_doc(cls): - for name, func in vars(cls).items(): - # only inherit docstring for public functions - if name.startswith("_"): - continue - if not func.__doc__: - for parent in cls.__bases__: - parent_func = getattr(parent, name, None) - if parent_func and getattr(parent_func, "__doc__", None): - func.__doc__ = parent_func.__doc__ - break - return cls +def keyword_only(func): + """ + A decorator that forces keyword arguments in the wrapped method + and saves actual input keyword arguments in `_input_kwargs`. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) > 1: + raise TypeError("Method %s forces keyword arguments." % func.__name__) + wrapper._input_kwargs = kwargs + return func(*args, **kwargs) + return wrapper class Identifiable(object): diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 9e12ddc3d9b8f..31a66b3d2f730 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -21,7 +21,7 @@ from pyspark.sql import DataFrame from pyspark.ml.param import Params from pyspark.ml.pipeline import Estimator, Transformer -from pyspark.ml.util import inherit_doc +from pyspark.mllib.common import inherit_doc def _jvm(): @@ -102,7 +102,7 @@ def _fit_java(self, dataset, params={}): """ Fits a Java model to the input dataset. :param dataset: input dataset, which is an instance of - :py:class:`pyspark.sql.SchemaRDD` + :py:class:`pyspark.sql.DataFrame` :param params: additional params (overwriting embedded values) :return: fitted Java model """ diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index c3217620e3c4e..6449800d9c120 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -19,7 +19,7 @@ Python bindings for MLlib. """ -# MLlib currently needs and NumPy 1.4+, so complain if lower +# MLlib currently needs NumPy 1.4+, so complain if lower import numpy if numpy.version.version < '1.4': diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 00e2e76711e84..e4765173709e8 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -207,7 +207,7 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i, - float(regParam), str(regType), bool(intercept), int(corrections), + float(regParam), regType, bool(intercept), int(corrections), float(tolerance)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index f6b97abb1723c..949db5705abd7 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -152,7 +152,7 @@ def predictSoft(self, x): class GaussianMixture(object): """ - Estimate model parameters with the expectation-maximization algorithm. + Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. :param data: RDD of data points :param k: Number of components diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 3c5ee66cd8b64..621591c26b77f 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -134,3 +134,20 @@ def __del__(self): def call(self, name, *a): """Call method of java_model""" return callJavaFunc(self._sc, getattr(self._java_model, name), *a) + + +def inherit_doc(cls): + """ + A decorator that makes a class inherit documentation from its parents. + """ + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 10df6288065b8..0ffe092a07365 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -58,7 +58,8 @@ class Normalizer(VectorTransformer): For any 1 <= `p` < float('inf'), normalizes samples using sum(abs(vector) :sup:`p`) :sup:`(1/p)` as norm. - For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. + For `p` = float('inf'), max(abs(vector)) will be used as norm for + normalization. >>> v = Vectors.dense(range(3)) >>> nor = Normalizer(1) @@ -120,9 +121,14 @@ def transform(self, vector): """ Applies standardization transformation on a vector. + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + :param vector: Vector or RDD of Vector to be standardized. - :return: Standardized vector. If the variance of a column is zero, - it will return default `0.0` for the column with zero variance. + :return: Standardized vector. If the variance of a column is + zero, it will return default `0.0` for the column with + zero variance. """ return JavaVectorTransformer.transform(self, vector) @@ -148,9 +154,10 @@ def __init__(self, withMean=False, withStd=True): """ :param withMean: False by default. Centers the data with mean before scaling. It will build a dense output, so this - does not work on sparse input and will raise an exception. - :param withStd: True by default. Scales the data to unit standard - deviation. + does not work on sparse input and will raise an + exception. + :param withStd: True by default. Scales the data to unit + standard deviation. """ if not (withMean or withStd): warnings.warn("Both withMean and withStd are false. The model does nothing.") @@ -159,10 +166,11 @@ def __init__(self, withMean=False, withStd=True): def fit(self, dataset): """ - Computes the mean and variance and stores as a model to be used for later scaling. + Computes the mean and variance and stores as a model to be used + for later scaling. - :param data: The data used to compute the mean and variance to build - the transformation model. + :param data: The data used to compute the mean and variance + to build the transformation model. :return: a StandardScalarModel """ dataset = dataset.map(_convert_to_vector) @@ -174,7 +182,8 @@ class HashingTF(object): """ .. note:: Experimental - Maps a sequence of terms to their term frequencies using the hashing trick. + Maps a sequence of terms to their term frequencies using the hashing + trick. Note: the terms must be hashable (can not be dict/set/list...). @@ -195,8 +204,9 @@ def indexOf(self, term): def transform(self, document): """ - Transforms the input document (list of terms) to term frequency vectors, - or transform the RDD of document to RDD of term frequency vectors. + Transforms the input document (list of terms) to term frequency + vectors, or transform the RDD of document to RDD of term + frequency vectors. """ if isinstance(document, RDD): return document.map(self.transform) @@ -220,7 +230,12 @@ def transform(self, x): the terms which occur in fewer than `minDocFreq` documents will have an entry of 0. - :param x: an RDD of term frequency vectors or a term frequency vector + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + + :param x: an RDD of term frequency vectors or a term frequency + vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ if isinstance(x, RDD): @@ -241,9 +256,9 @@ class IDF(object): of documents that contain term `t`. This implementation supports filtering out terms which do not appear - in a minimum number of documents (controlled by the variable `minDocFreq`). - For terms that are not in at least `minDocFreq` documents, the IDF is - found as 0, resulting in TF-IDFs of 0. + in a minimum number of documents (controlled by the variable + `minDocFreq`). For terms that are not in at least `minDocFreq` + documents, the IDF is found as 0, resulting in TF-IDFs of 0. >>> n = 4 >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), @@ -325,15 +340,16 @@ class Word2Vec(object): The vector representation can be used as features in natural language processing and machine learning algorithms. - We used skip-gram model in our implementation and hierarchical softmax - method to train the model. The variable names in the implementation - matches the original C implementation. + We used skip-gram model in our implementation and hierarchical + softmax method to train the model. The variable names in the + implementation matches the original C implementation. - For original C implementation, see https://code.google.com/p/word2vec/ + For original C implementation, + see https://code.google.com/p/word2vec/ For research papers, see Efficient Estimation of Word Representations in Vector Space - and - Distributed Representations of Words and Phrases and their Compositionality. + and Distributed Representations of Words and Phrases and their + Compositionality. >>> sentence = "a b " * 100 + "a c " * 10 >>> localDoc = [sentence, sentence] @@ -374,15 +390,16 @@ def setLearningRate(self, learningRate): def setNumPartitions(self, numPartitions): """ - Sets number of partitions (default: 1). Use a small number for accuracy. + Sets number of partitions (default: 1). Use a small number for + accuracy. """ self.numPartitions = numPartitions return self def setNumIterations(self, numIterations): """ - Sets number of iterations (default: 1), which should be smaller than or equal to number of - partitions. + Sets number of iterations (default: 1), which should be smaller + than or equal to number of partitions. """ self.numIterations = numIterations return self diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 7f21190ed8c25..8b791ff6a7877 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,7 +29,7 @@ import numpy as np -from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ +from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ IntegerType, ByteType @@ -152,6 +152,9 @@ def deserialize(self, datum): else: raise ValueError("do not recognize type %r" % tpe) + def simpleString(self): + return "vector" + class Vector(object): @@ -170,7 +173,24 @@ def toArray(self): class DenseVector(Vector): """ - A dense vector represented by a value array. + A dense vector represented by a value array. We use numpy array for + storage and arithmetics will be delegated to the underlying numpy + array. + + >>> v = Vectors.dense([1.0, 2.0]) + >>> u = Vectors.dense([3.0, 4.0]) + >>> v + u + DenseVector([4.0, 6.0]) + >>> 2 - v + DenseVector([1.0, 0.0]) + >>> v / 2 + DenseVector([0.5, 1.0]) + >>> v * u + DenseVector([3.0, 8.0]) + >>> u / v + DenseVector([3.0, 2.0]) + >>> u % 2 + DenseVector([1.0, 0.0]) """ def __init__(self, ar): if isinstance(ar, basestring): @@ -289,6 +309,25 @@ def __ne__(self, other): def __getattr__(self, item): return getattr(self.array, item) + def _delegate(op): + def func(self, other): + if isinstance(other, DenseVector): + other = other.array + return DenseVector(getattr(self.array, op)(other)) + return func + + __neg__ = _delegate("__neg__") + __add__ = _delegate("__add__") + __sub__ = _delegate("__sub__") + __mul__ = _delegate("__mul__") + __div__ = _delegate("__div__") + __mod__ = _delegate("__mod__") + __radd__ = _delegate("__radd__") + __rsub__ = _delegate("__rsub__") + __rmul__ = _delegate("__rmul__") + __rdiv__ = _delegate("__rdiv__") + __rmod__ = _delegate("__rmod__") + class SparseVector(Vector): """ diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 0d99e6dedfad9..c5c4c13dae105 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -19,7 +19,8 @@ from pyspark import SparkContext from pyspark.rdd import RDD -from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc +from pyspark.mllib.util import JavaLoader, JavaSaveable __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] @@ -39,7 +40,8 @@ def __reduce__(self): return Rating, (int(self.user), int(self.product), float(self.rating)) -class MatrixFactorizationModel(JavaModelWrapper): +@inherit_doc +class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): """A matrix factorisation model trained by regularized alternating least-squares. @@ -50,7 +52,7 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> ratings = sc.parallelize([r1, r2, r3]) >>> model = ALS.trainImplicit(ratings, 1, seed=10) >>> model.predict(2, 2) - 0.43... + 0.4... >>> testset = sc.parallelize([(1, 2), (1, 1)]) >>> model = ALS.train(ratings, 2, seed=0) @@ -80,7 +82,20 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2,2) - 0.43... + 0.4... + + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = MatrixFactorizationModel.load(sc, path) + >>> sameModel.predict(2,2) + 0.4... + >>> sameModel.predictAll(testset).collect() + [Rating(... + >>> try: + ... os.removedirs(path) + ... except OSError: + ... pass """ def predict(self, user, product): return self._java_model.predict(int(user), int(product)) @@ -98,6 +113,12 @@ def userFeatures(self): def productFeatures(self): return self.call("getProductFeatures") + @classmethod + def load(cls, sc, path): + model = cls._load_java(sc, path) + wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) + return MatrixFactorizationModel(wrapper) + class ALS(object): diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 210060140fd91..ad2b0505e765b 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,11 +18,13 @@ import numpy as np from numpy import array -from pyspark.mllib.common import callMLlibFunc +from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import SparseVector, _convert_to_vector -__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', - 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] +__all__ = ['LabeledPoint', 'LinearModel', + 'LinearRegressionModel', 'LinearRegressionWithSGD', + 'RidgeRegressionModel', 'RidgeRegressionWithSGD', + 'LassoModel', 'LassoWithSGD'] class LabeledPoint(object): @@ -31,8 +33,11 @@ class LabeledPoint(object): The features and labels of a data point. :param label: Label for this data point. - :param features: Vector of features for this point (NumPy array, list, - pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) + :param features: Vector of features for this point (NumPy array, + list, pyspark.mllib.linalg.SparseVector, or scipy.sparse + column matrix) + + Note: 'label' and 'features' are accessible as class attributes. """ def __init__(self, label, features): @@ -69,6 +74,7 @@ def __repr__(self): return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept) +@inherit_doc class LinearRegressionModelBase(LinearModel): """A linear regression model. @@ -89,6 +95,7 @@ def predict(self, x): return self.weights.dot(x) + self.intercept +@inherit_doc class LinearRegressionModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit. @@ -128,7 +135,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): first = data.first() if not isinstance(first, LabeledPoint): raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) - initial_weights = initial_weights or [0.0] * len(data.first().features) + if initial_weights is None: + initial_weights = [0.0] * len(data.first().features) weights, intercept = train_func(data, _convert_to_vector(initial_weights)) return modelClass(weights, intercept) @@ -162,7 +170,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features - are activated or not). + are activated or not). (default: False) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -172,6 +180,7 @@ def train(rdd, i): return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) +@inherit_doc class LassoModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit with an @@ -218,6 +227,7 @@ def train(rdd, i): return _regression_train_wrapper(train, LassoModel, data, initialWeights) +@inherit_doc class RidgeRegressionModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit with an diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py index b686d955a0080..e3e128513e0d7 100644 --- a/python/pyspark/mllib/stat/__init__.py +++ b/python/pyspark/mllib/stat/__init__.py @@ -21,5 +21,7 @@ from pyspark.mllib.stat._statistics import * from pyspark.mllib.stat.distribution import MultivariateGaussian +from pyspark.mllib.stat.test import ChiSqTestResult -__all__ = ["Statistics", "MultivariateStatisticalSummary", "MultivariateGaussian"] +__all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult", + "MultivariateGaussian"] diff --git a/python/pyspark/mllib/stat/distribution.py b/python/pyspark/mllib/stat/distribution.py index 07792e1532046..46f7a1d2f277a 100644 --- a/python/pyspark/mllib/stat/distribution.py +++ b/python/pyspark/mllib/stat/distribution.py @@ -22,7 +22,8 @@ class MultivariateGaussian(namedtuple('MultivariateGaussian', ['mu', 'sigma'])): - """ Represents a (mu, sigma) tuple + """Represents a (mu, sigma) tuple + >>> m = MultivariateGaussian(Vectors.dense([11,12]),DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0))) >>> (m.mu, m.sigma.toArray()) (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]])) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 42aa22873772d..113f2163fb6e4 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -19,7 +19,9 @@ Fuller unit tests for Python MLlib. """ +import os import sys +import tempfile import array as pyarray from numpy import array, array_equal @@ -34,6 +36,7 @@ else: import unittest +from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, Vectors, Matrices from pyspark.mllib.regression import LabeledPoint @@ -195,7 +198,8 @@ def test_gmm_deterministic(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees + from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ + RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel data = [ LabeledPoint(0.0, [1, 0, 0]), LabeledPoint(1.0, [0, 1, 1]), @@ -205,6 +209,8 @@ def test_classification(self): rdd = self.sc.parallelize(data) features = [p.features.tolist() for p in data] + temp_dir = tempfile.mkdtemp() + lr_model = LogisticRegressionWithSGD.train(rdd) self.assertTrue(lr_model.predict(features[0]) <= 0) self.assertTrue(lr_model.predict(features[1]) > 0) @@ -231,6 +237,11 @@ def test_classification(self): self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) + dt_model_dir = os.path.join(temp_dir, "dt") + dt_model.save(self.sc, dt_model_dir) + same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir) + self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) + rf_model = RandomForest.trainClassifier( rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100) self.assertTrue(rf_model.predict(features[0]) <= 0) @@ -238,6 +249,11 @@ def test_classification(self): self.assertTrue(rf_model.predict(features[2]) <= 0) self.assertTrue(rf_model.predict(features[3]) > 0) + rf_model_dir = os.path.join(temp_dir, "rf") + rf_model.save(self.sc, rf_model_dir) + same_rf_model = RandomForestModel.load(self.sc, rf_model_dir) + self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) + gbt_model = GradientBoostedTrees.trainClassifier( rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(gbt_model.predict(features[0]) <= 0) @@ -245,6 +261,16 @@ def test_classification(self): self.assertTrue(gbt_model.predict(features[2]) <= 0) self.assertTrue(gbt_model.predict(features[3]) > 0) + gbt_model_dir = os.path.join(temp_dir, "gbt") + gbt_model.save(self.sc, gbt_model_dir) + same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir) + self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) + + try: + os.removedirs(temp_dir) + except OSError: + pass + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD @@ -285,7 +311,7 @@ def test_regression(self): self.assertTrue(dt_model.predict(features[3]) > 0) rf_model = RandomForest.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100, seed=1) self.assertTrue(rf_model.predict(features[0]) <= 0) self.assertTrue(rf_model.predict(features[1]) > 0) self.assertTrue(rf_model.predict(features[2]) <= 0) @@ -298,6 +324,13 @@ def test_regression(self): self.assertTrue(gbt_model.predict(features[2]) <= 0) self.assertTrue(gbt_model.predict(features[3]) > 0) + try: + LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0])) + LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0])) + RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0])) + except ValueError: + self.fail() + class StatTests(PySparkTestCase): # SPARK-4023 @@ -335,7 +368,7 @@ def test_infer_schema(self): sqlCtx = SQLContext(self.sc) rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) srdd = sqlCtx.inferSchema(rdd) - schema = srdd.schema() + schema = srdd.schema field = [f for f in schema.fields if f.name == "features"][0] self.assertEqual(field.dataType, self.udt) vectors = srdd.map(lambda p: p.features).collect() @@ -588,6 +621,13 @@ def test_right_number_of_results(self): self.assertEqual(len(chi), num_cols) self.assertIsNotNone(chi[1000]) + +class SerDeTest(PySparkTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed" diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index aae48f213246b..a7a4d2aaf855b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -20,19 +20,24 @@ import random from pyspark import SparkContext, RDD -from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import JavaLoader, JavaSaveable __all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', - 'RandomForest', 'GradientBoostedTrees'] + 'RandomForest', 'GradientBoostedTreesModel', 'GradientBoostedTrees'] -class TreeEnsembleModel(JavaModelWrapper): +class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): def predict(self, x): """ Predict values for a single data point or an RDD of points using the model trained. + + Note: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. """ if isinstance(x, RDD): return self.call("predict", x.map(_convert_to_vector)) @@ -48,7 +53,8 @@ def numTrees(self): def totalNumNodes(self): """ - Get total number of nodes, summed over all trees in the ensemble. + Get total number of nodes, summed over all trees in the + ensemble. """ return self.call("totalNumNodes") @@ -61,7 +67,7 @@ def toDebugString(self): return self._java_model.toDebugString() -class DecisionTreeModel(JavaModelWrapper): +class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ .. note:: Experimental @@ -71,6 +77,10 @@ def predict(self, x): """ Predict the label of one or more examples. + Note: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. + :param x: Data point (feature vector), or an RDD of data points (feature vectors). """ @@ -94,12 +104,17 @@ def toDebugString(self): """ full model. """ return self._java_model.toDebugString() + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.DecisionTreeModel" + class DecisionTree(object): """ .. note:: Experimental - Learning algorithm for a decision tree model for classification or regression. + Learning algorithm for a decision tree model for classification or + regression. """ @classmethod @@ -176,17 +191,17 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, :param data: Training data: RDD of LabeledPoint. Labels are real numbers. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. + :param categoricalFeaturesInfo: Map from categorical feature + index to number of categories. + Any feature not in this map is treated as continuous. :param impurity: Supported values: "variance" :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :param minInstancesPerNode: Min number of instances required at child - nodes to create the parent split + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each + node. + :param minInstancesPerNode: Min number of instances required at + child nodes to create the parent split :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel @@ -216,19 +231,25 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) -class RandomForestModel(TreeEnsembleModel): +@inherit_doc +class RandomForestModel(TreeEnsembleModel, JavaLoader): """ .. note:: Experimental Represents a random forest model. """ + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.RandomForestModel" + class RandomForest(object): """ .. note:: Experimental - Learning algorithm for a random forest model for classification or regression. + Learning algorithm for a random forest model for classification or + regression. """ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -255,26 +276,30 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Method to train a decision tree model for binary or multiclass classification. - :param data: Training dataset: RDD of LabeledPoint. Labels should take - values {0, 1, ..., numClasses-1}. + :param data: Training dataset: RDD of LabeledPoint. Labels + should take values {0, 1, ..., numClasses-1}. :param numClasses: number of classes for classification. - :param categoricalFeaturesInfo: Map storing arity of categorical features. - E.g., an entry (n -> k) indicates that feature n is categorical - with k categories indexed from 0: {0, 1, ..., k-1}. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that + feature n is categorical with k categories indexed + from 0: {0, 1, ..., k-1}. :param numTrees: Number of trees in the random forest. - :param featureSubsetStrategy: Number of features to consider for splits at - each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". + :param featureSubsetStrategy: Number of features to consider for + splits at each node. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". :param impurity: Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy". - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; - depth 1 means 1 internal node + 2 leaf nodes. (default: 4) - :param maxBins: maximum number of bins used for splitting features - (default: 100) - :param seed: Random seed for bootstrapping and choosing feature subsets. + :param maxDepth: Maximum depth of the tree. + E.g., depth 0 means 1 leaf node; depth 1 means + 1 internal node + 2 leaf nodes. (default: 4) + :param maxBins: maximum number of bins used for splitting + features + (default: 100) + :param seed: Random seed for bootstrapping and choosing feature + subsets. :return: RandomForestModel that can be used for prediction Example usage: @@ -336,19 +361,21 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt {0, 1, ..., k-1}. :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for - splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for regression. - :param impurity: Criterion used for information gain calculation. - Supported values: "variance". - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 - leaf node; depth 1 means 1 internal node + 2 leaf nodes. - (default: 4) - :param maxBins: maximum number of bins used for splitting features - (default: 100) - :param seed: Random seed for bootstrapping and choosing feature subsets. + splits at each node. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. + :param impurity: Criterion used for information gain + calculation. + Supported values: "variance". + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. (default: 4) + :param maxBins: maximum number of bins used for splitting + features (default: 100) + :param seed: Random seed for bootstrapping and choosing feature + subsets. :return: RandomForestModel that can be used for prediction Example usage: @@ -381,19 +408,25 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt featureSubsetStrategy, impurity, maxDepth, maxBins, seed) -class GradientBoostedTreesModel(TreeEnsembleModel): +@inherit_doc +class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): """ .. note:: Experimental Represents a gradient-boosted tree model. """ + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel" + class GradientBoostedTrees(object): """ .. note:: Experimental - Learning algorithm for a gradient boosted trees model for classification or regression. + Learning algorithm for a gradient boosted trees model for + classification or regression. """ @classmethod @@ -409,24 +442,29 @@ def _train(cls, data, algo, categoricalFeaturesInfo, def trainClassifier(cls, data, categoricalFeaturesInfo, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3): """ - Method to train a gradient-boosted trees model for classification. + Method to train a gradient-boosted trees model for + classification. - :param data: Training dataset: RDD of LabeledPoint. Labels should take values {0, 1}. + :param data: Training dataset: RDD of LabeledPoint. + Labels should take values {0, 1}. :param categoricalFeaturesInfo: Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}. - :param loss: Loss function used for minimization during gradient boosting. - Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}. + :param loss: Loss function used for minimization during gradient + boosting. Supported: {"logLoss" (default), + "leastSquaresError", "leastAbsoluteError"}. :param numIterations: Number of iterations of boosting. (default: 100) - :param learningRate: Learning rate for shrinking the contribution of each estimator. - The learning rate should be between in the interval (0, 1] - (default: 0.1) - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 - leaf node; depth 1 means 1 internal node + 2 leaf nodes. - (default: 3) - :return: GradientBoostedTreesModel that can be used for prediction + :param learningRate: Learning rate for shrinking the + contribution of each estimator. The learning rate + should be between in the interval (0, 1]. + (default: 0.1) + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. (default: 3) + :return: GradientBoostedTreesModel that can be used for + prediction Example usage: @@ -470,17 +508,20 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}. - :param loss: Loss function used for minimization during gradient boosting. - Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}. + :param loss: Loss function used for minimization during gradient + boosting. Supported: {"logLoss" (default), + "leastSquaresError", "leastAbsoluteError"}. :param numIterations: Number of iterations of boosting. (default: 100) - :param learningRate: Learning rate for shrinking the contribution of each estimator. - The learning rate should be between in the interval (0, 1] - (default: 0.1) - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 - leaf node; depth 1 means 1 internal node + 2 leaf nodes. - (default: 3) - :return: GradientBoostedTreesModel that can be used for prediction + :param learningRate: Learning rate for shrinking the + contribution of each estimator. The learning rate + should be between in the interval (0, 1]. + (default: 0.1) + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. (default: 3) + :return: GradientBoostedTreesModel that can be used for + prediction Example usage: diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 4ed978b45409c..e877c720ac77a 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -18,7 +18,7 @@ import numpy as np import warnings -from pyspark.mllib.common import callMLlibFunc +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -168,6 +168,93 @@ def loadLabeledPoints(sc, path, minPartitions=None): return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) +class Saveable(object): + """ + Mixin for models and transformers which may be saved as files. + """ + + def save(self, sc, path): + """ + Save this model to the given path. + + This saves: + * human-readable (JSON) model metadata to path/metadata/ + * Parquet formatted data to path/data/ + + The model may be loaded using py:meth:`Loader.load`. + + :param sc: Spark context used to save model data. + :param path: Path specifying the directory in which to save + this model. If the directory already exists, + this method throws an exception. + """ + raise NotImplementedError + + +@inherit_doc +class JavaSaveable(Saveable): + """ + Mixin for models that provide save() through their Scala + implementation. + """ + + def save(self, sc, path): + self._java_model.save(sc._jsc.sc(), path) + + +class Loader(object): + """ + Mixin for classes which can load saved models from files. + """ + + @classmethod + def load(cls, sc, path): + """ + Load a model from the given path. The model should have been + saved using py:meth:`Saveable.save`. + + :param sc: Spark context used for loading model files. + :param path: Path specifying the directory to which the model + was saved. + :return: model instance + """ + raise NotImplemented + + +@inherit_doc +class JavaLoader(Loader): + """ + Mixin for classes which can load saved models using its Scala + implementation. + """ + + @classmethod + def _java_loader_class(cls): + """ + Returns the full class name of the Java loader. The default + implementation replaces "pyspark" by "org.apache.spark" in + the Python full class name. + """ + java_package = cls.__module__.replace("pyspark", "org.apache.spark") + return ".".join([java_package, cls.__name__]) + + @classmethod + def _load_java(cls, sc, path): + """ + Load a Java model from the given path. + """ + java_class = cls._java_loader_class() + java_obj = sc._jvm + for name in java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj.load(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = cls._load_java(sc, path) + return cls(java_model) + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2f8a0edfe9644..eb8c6b4253873 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -19,7 +19,6 @@ from collections import defaultdict from itertools import chain, ifilter, imap import operator -import os import sys import shlex from subprocess import Popen, PIPE @@ -29,6 +28,7 @@ import heapq import bisect import random +import socket from math import sqrt, log, isinf, isnan, pow, ceil from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ @@ -111,6 +111,31 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) +def _load_from_socket(port, serializer): + sock = socket.socket() + sock.settimeout(3) + try: + sock.connect(("localhost", port)) + rf = sock.makefile("rb", 65536) + for item in serializer.load_stream(rf): + yield item + finally: + sock.close() + + +class Partitioner(object): + def __init__(self, numPartitions, partitionFunc): + self.numPartitions = numPartitions + self.partitionFunc = partitionFunc + + def __eq__(self, other): + return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions + and self.partitionFunc == other.partitionFunc) + + def __call__(self, k): + return self.partitionFunc(k) % self.numPartitions + + class RDD(object): """ @@ -126,7 +151,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSeri self.ctx = ctx self._jrdd_deserializer = jrdd_deserializer self._id = jrdd.id() - self._partitionFunc = None + self.partitioner = None def _pickled(self): return self._reserialize(AutoBatchedSerializer(PickleSerializer())) @@ -450,14 +475,17 @@ def union(self, other): if self._jrdd_deserializer == other._jrdd_deserializer: rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, self._jrdd_deserializer) - return rdd else: # These RDDs contain data in different serialized formats, so we # must normalize them to the default serializer. self_copy = self._reserialize() other_copy = other._reserialize() - return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, - self.ctx.serializer) + rdd = RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, + self.ctx.serializer) + if (self.partitioner == other.partitioner and + self.getNumPartitions() == rdd.getNumPartitions()): + rdd.partitioner = self.partitioner + return rdd def intersection(self, other): """ @@ -561,7 +589,7 @@ def sortPartition(iterator): maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() - samples = sorted(samples, reverse=(not ascending), key=keyfunc) + samples = sorted(samples, key=keyfunc) # we have numPartitions many parts but one of the them has # an implicit boundary @@ -682,21 +710,8 @@ def collect(self): Return a list that contains all of the elements in this RDD. """ with SCCallSiteSync(self.context) as css: - bytesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(bytesInJava)) - - def _collect_iterator_through_file(self, iterator): - # Transferring lots of data through Py4J can be slow because - # socket.readline() is inefficient. Instead, we'll dump the data to a - # file and read it back. - tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) - tempFile.close() - self.ctx._writeToFile(iterator, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - for item in self._jrdd_deserializer.load_stream(tempFile): - yield item - os.unlink(tempFile.name) + port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) + return list(_load_from_socket(port, self._jrdd_deserializer)) def reduce(self, f): """ @@ -1366,10 +1381,14 @@ def saveAsPickleFile(self, path, batchSize=10): ser = BatchedSerializer(PickleSerializer(), batchSize) self._reserialize(ser)._jrdd.saveAsObjectFile(path) - def saveAsTextFile(self, path): + def saveAsTextFile(self, path, compressionCodecClass=None): """ Save this RDD as a text file, using string representations of elements. + @param path: path to text file + @param compressionCodecClass: (None by default) string i.e. + "org.apache.hadoop.io.compress.GzipCodec" + >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name) @@ -1385,6 +1404,16 @@ def saveAsTextFile(self, path): >>> sc.parallelize(['', 'foo', '', 'bar', '']).saveAsTextFile(tempFile2.name) >>> ''.join(sorted(input(glob(tempFile2.name + "/part-0000*")))) '\\n\\n\\nbar\\nfoo\\n' + + Using compressionCodecClass + + >>> tempFile3 = NamedTemporaryFile(delete=True) + >>> tempFile3.close() + >>> codec = "org.apache.hadoop.io.compress.GzipCodec" + >>> sc.parallelize(['foo', 'bar']).saveAsTextFile(tempFile3.name, codec) + >>> from fileinput import input, hook_compressed + >>> ''.join(sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed))) + 'bar\\nfoo\\n' """ def func(split, iterator): for x in iterator: @@ -1395,7 +1424,11 @@ def func(split, iterator): yield x keyed = self.mapPartitionsWithIndex(func) keyed._bypass_serializer = True - keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) + if compressionCodecClass: + compressionCodec = self.ctx._jvm.java.lang.Class.forName(compressionCodecClass) + keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path, compressionCodec) + else: + keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) # Pair functions @@ -1570,6 +1603,9 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): """ if numPartitions is None: numPartitions = self._defaultReducePartitions() + partitioner = Partitioner(numPartitions, partitionFunc) + if self.partitioner == partitioner: + return self # Transferring O(n) objects to Java is too expensive. # Instead, we'll form the hash buckets in Python, @@ -1614,18 +1650,16 @@ def add_shuffle_key(split, iterator): yield pack_long(split) yield outputSerializer.dumps(items) - keyed = self.mapPartitionsWithIndex(add_shuffle_key) + keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True) keyed._bypass_serializer = True with SCCallSiteSync(self.context) as css: pairRDD = self.ctx._jvm.PairwiseRDD( keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, - id(partitionFunc)) - jrdd = pairRDD.partitionBy(partitioner).values() + jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions, + id(partitionFunc)) + jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner)) rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) - # This is required so that id(partitionFunc) remains unique, - # even if partitionFunc is a lambda: - rdd._partitionFunc = partitionFunc + rdd.partitioner = partitioner return rdd # TODO: add control over map-side aggregation @@ -1671,7 +1705,7 @@ def combineLocally(iterator): merger.mergeValues(iterator) return merger.iteritems() - locally_combined = self.mapPartitions(combineLocally) + locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): @@ -1680,7 +1714,7 @@ def _mergeCombiners(iterator): merger.mergeCombiners(iterator) return merger.iteritems() - return shuffled.mapPartitions(_mergeCombiners, True) + return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ @@ -1915,7 +1949,7 @@ def batch_as(rdd, batchSize): my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) - if my_batch != other_batch: + if my_batch != other_batch or not my_batch: # use the smallest batchSize for both of them batchSize = min(my_batch, other_batch) if batchSize <= 0: @@ -2059,8 +2093,8 @@ def lookup(self, key): """ values = self.filter(lambda (k, v): k == key).values() - if self._partitionFunc is not None: - return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False) + if self.partitioner is not None: + return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False) return values.collect() @@ -2076,6 +2110,7 @@ def _to_java_object_rdd(self): def countApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate version of count() that returns a potentially incomplete result within a timeout, even if not all tasks have finished. @@ -2089,6 +2124,7 @@ def countApprox(self, timeout, confidence=0.95): def sumApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate operation to return the sum within a timeout or meet the confidence. @@ -2105,6 +2141,7 @@ def sumApprox(self, timeout, confidence=0.95): def meanApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate operation to return the mean within a timeout or meet the confidence. @@ -2121,6 +2158,7 @@ def meanApprox(self, timeout, confidence=0.95): def countApproxDistinct(self, relativeSD=0.05): """ .. note:: Experimental + Return approximate number of distinct elements in the RDD. The algorithm used is based on streamlib's implementation of @@ -2162,6 +2200,25 @@ def toLocalIterator(self): yield row +def _prepare_for_python_RDD(sc, command, obj=None): + # the serialized command will be compressed by broadcast + ser = CloudPickleSerializer() + pickled_command = ser.dumps(command) + if len(pickled_command) > (1 << 20): # 1M + broadcast = sc.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) + # tracking the life cycle by obj + if obj is not None: + obj._broadcast = broadcast + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in sc._pickled_broadcast_vars], + sc._gateway._gateway_client) + sc._pickled_broadcast_vars.clear() + env = MapConverter().convert(sc.environment, sc._gateway._gateway_client) + includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client) + return pickled_command, broadcast_vars, env, includes + + class PipelinedRDD(RDD): """ @@ -2206,7 +2263,7 @@ def pipeline_func(split, iterator): self._id = None self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False - self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None + self.partitioner = prev.partitioner if self.preservesPartitioning else None self._broadcast = None def __del__(self): @@ -2228,25 +2285,12 @@ def _jrdd(self): command = (self.func, profiler, self._prev_jrdd_deserializer, self._jrdd_deserializer) - # the serialized command will be compressed by broadcast - ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - self._broadcast = self.ctx.broadcast(pickled_command) - pickled_command = ser.dumps(self._broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], - self.ctx._gateway._gateway_client) - self.ctx._pickled_broadcast_vars.clear() - env = MapConverter().convert(self.ctx.environment, - self.ctx._gateway._gateway_client) - includes = ListConverter().convert(self.ctx._python_includes, - self.ctx._gateway._gateway_client) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - bytearray(pickled_command), + bytearray(pickled_cmd), env, includes, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator) + bvars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() if profiler: diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 89cf76920e353..81aa970a32f76 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -31,13 +31,18 @@ import atexit import os import platform + +import py4j + import pyspark from pyspark.context import SparkContext +from pyspark.sql import SQLContext, HiveContext from pyspark.storagelevel import StorageLevel -# this is the equivalent of ADD_JARS -add_files = (os.environ.get("ADD_FILES").split(',') - if os.environ.get("ADD_FILES") is not None else None) +# this is the deprecated equivalent of ADD_JARS +add_files = None +if os.environ.get("ADD_FILES") is not None: + add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) @@ -45,6 +50,13 @@ sc = SparkContext(appName="PySparkShell", pyFiles=add_files) atexit.register(lambda: sc.stop()) +try: + # Try to access HiveConf, it will raise exception if Hive is not added + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + sqlCtx = sqlContext = HiveContext(sc) +except py4j.protocol.Py4JError: + sqlCtx = sqlContext = SQLContext(sc) + print("""Welcome to ____ __ / __/__ ___ _____/ /__ @@ -56,9 +68,10 @@ platform.python_version(), platform.python_build()[0], platform.python_build()[1])) -print("SparkContext available as sc.") +print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__) if add_files is not None: + print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead") print("Adding files: [%s]" % ", ".join(add_files)) # The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py deleted file mode 100644 index 32bff0c7e8c55..0000000000000 --- a/python/pyspark/sql.py +++ /dev/null @@ -1,2600 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -public classes of Spark SQL: - - - L{SQLContext} - Main entry point for SQL functionality. - - L{DataFrame} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, DataFrames also support SQL. - - L{GroupedDataFrame} - - L{Column} - Column is a DataFrame with a single column. - - L{Row} - A Row of data returned by a Spark SQL query. - - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. -""" - -import sys -import itertools -import decimal -import datetime -import keyword -import warnings -import json -import re -import random -import os -from tempfile import NamedTemporaryFile -from array import array -from operator import itemgetter -from itertools import imap - -from py4j.protocol import Py4JError -from py4j.java_collections import ListConverter, MapConverter - -from pyspark.context import SparkContext -from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ - CloudPickleSerializer, UTF8Deserializer -from pyspark.storagelevel import StorageLevel -from pyspark.traceback_utils import SCCallSiteSync - - -__all__ = [ - "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", - "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", - "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", - "SchemaRDD"] - - -class DataType(object): - - """Spark SQL DataType""" - - def __repr__(self): - return self.__class__.__name__ - - def __hash__(self): - return hash(str(self)) - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) - - def __ne__(self, other): - return not self.__eq__(other) - - @classmethod - def typeName(cls): - return cls.__name__[:-4].lower() - - def jsonValue(self): - return self.typeName() - - def json(self): - return json.dumps(self.jsonValue(), - separators=(',', ':'), - sort_keys=True) - - -class PrimitiveTypeSingleton(type): - - """Metaclass for PrimitiveType""" - - _instances = {} - - def __call__(cls): - if cls not in cls._instances: - cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() - return cls._instances[cls] - - -class PrimitiveType(DataType): - - """Spark SQL PrimitiveType""" - - __metaclass__ = PrimitiveTypeSingleton - - def __eq__(self, other): - # because they should be the same object - return self is other - - -class NullType(PrimitiveType): - - """Spark SQL NullType - - The data type representing None, used for the types which has not - been inferred. - """ - - -class StringType(PrimitiveType): - - """Spark SQL StringType - - The data type representing string values. - """ - - -class BinaryType(PrimitiveType): - - """Spark SQL BinaryType - - The data type representing bytearray values. - """ - - -class BooleanType(PrimitiveType): - - """Spark SQL BooleanType - - The data type representing bool values. - """ - - -class DateType(PrimitiveType): - - """Spark SQL DateType - - The data type representing datetime.date values. - """ - - -class TimestampType(PrimitiveType): - - """Spark SQL TimestampType - - The data type representing datetime.datetime values. - """ - - -class DecimalType(DataType): - - """Spark SQL DecimalType - - The data type representing decimal.Decimal values. - """ - - def __init__(self, precision=None, scale=None): - self.precision = precision - self.scale = scale - self.hasPrecisionInfo = precision is not None - - def jsonValue(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal" - - def __repr__(self): - if self.hasPrecisionInfo: - return "DecimalType(%d,%d)" % (self.precision, self.scale) - else: - return "DecimalType()" - - -class DoubleType(PrimitiveType): - - """Spark SQL DoubleType - - The data type representing float values. - """ - - -class FloatType(PrimitiveType): - - """Spark SQL FloatType - - The data type representing single precision floating-point values. - """ - - -class ByteType(PrimitiveType): - - """Spark SQL ByteType - - The data type representing int values with 1 singed byte. - """ - - -class IntegerType(PrimitiveType): - - """Spark SQL IntegerType - - The data type representing int values. - """ - - -class LongType(PrimitiveType): - - """Spark SQL LongType - - The data type representing long values. If the any value is - beyond the range of [-9223372036854775808, 9223372036854775807], - please use DecimalType. - """ - - -class ShortType(PrimitiveType): - - """Spark SQL ShortType - - The data type representing int values with 2 signed bytes. - """ - - -class ArrayType(DataType): - - """Spark SQL ArrayType - - The data type representing list values. An ArrayType object - comprises two fields, elementType (a DataType) and containsNull (a bool). - The field of elementType is used to specify the type of array elements. - The field of containsNull is used to specify if the array has None values. - - """ - - def __init__(self, elementType, containsNull=True): - """Creates an ArrayType - - :param elementType: the data type of elements. - :param containsNull: indicates whether the list contains None values. - - >>> ArrayType(StringType) == ArrayType(StringType, True) - True - >>> ArrayType(StringType, False) == ArrayType(StringType) - False - """ - self.elementType = elementType - self.containsNull = containsNull - - def __repr__(self): - return "ArrayType(%s,%s)" % (self.elementType, - str(self.containsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "elementType": self.elementType.jsonValue(), - "containsNull": self.containsNull} - - @classmethod - def fromJson(cls, json): - return ArrayType(_parse_datatype_json_value(json["elementType"]), - json["containsNull"]) - - -class MapType(DataType): - - """Spark SQL MapType - - The data type representing dict values. A MapType object comprises - three fields, keyType (a DataType), valueType (a DataType) and - valueContainsNull (a bool). - - The field of keyType is used to specify the type of keys in the map. - The field of valueType is used to specify the type of values in the map. - The field of valueContainsNull is used to specify if values of this - map has None values. - - For values of a MapType column, keys are not allowed to have None values. - - """ - - def __init__(self, keyType, valueType, valueContainsNull=True): - """Creates a MapType - :param keyType: the data type of keys. - :param valueType: the data type of values. - :param valueContainsNull: indicates whether values contains - null values. - - >>> (MapType(StringType, IntegerType) - ... == MapType(StringType, IntegerType, True)) - True - >>> (MapType(StringType, IntegerType, False) - ... == MapType(StringType, FloatType)) - False - """ - self.keyType = keyType - self.valueType = valueType - self.valueContainsNull = valueContainsNull - - def __repr__(self): - return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, - str(self.valueContainsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "keyType": self.keyType.jsonValue(), - "valueType": self.valueType.jsonValue(), - "valueContainsNull": self.valueContainsNull} - - @classmethod - def fromJson(cls, json): - return MapType(_parse_datatype_json_value(json["keyType"]), - _parse_datatype_json_value(json["valueType"]), - json["valueContainsNull"]) - - -class StructField(DataType): - - """Spark SQL StructField - - Represents a field in a StructType. - A StructField object comprises three fields, name (a string), - dataType (a DataType) and nullable (a bool). The field of name - is the name of a StructField. The field of dataType specifies - the data type of a StructField. - - The field of nullable specifies if values of a StructField can - contain None values. - - """ - - def __init__(self, name, dataType, nullable=True, metadata=None): - """Creates a StructField - :param name: the name of this field. - :param dataType: the data type of this field. - :param nullable: indicates whether values of this field - can be null. - :param metadata: metadata of this field, which is a map from string - to simple type that can be serialized to JSON - automatically - - >>> (StructField("f1", StringType, True) - ... == StructField("f1", StringType, True)) - True - >>> (StructField("f1", StringType, True) - ... == StructField("f2", StringType, True)) - False - """ - self.name = name - self.dataType = dataType - self.nullable = nullable - self.metadata = metadata or {} - - def __repr__(self): - return "StructField(%s,%s,%s)" % (self.name, self.dataType, - str(self.nullable).lower()) - - def jsonValue(self): - return {"name": self.name, - "type": self.dataType.jsonValue(), - "nullable": self.nullable, - "metadata": self.metadata} - - @classmethod - def fromJson(cls, json): - return StructField(json["name"], - _parse_datatype_json_value(json["type"]), - json["nullable"], - json["metadata"]) - - -class StructType(DataType): - - """Spark SQL StructType - - The data type representing rows. - A StructType object comprises a list of L{StructField}. - - """ - - def __init__(self, fields): - """Creates a StructType - - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True)]) - >>> struct1 == struct2 - True - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True), - ... [StructField("f2", IntegerType, False)]]) - >>> struct1 == struct2 - False - """ - self.fields = fields - - def __repr__(self): - return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) - - def jsonValue(self): - return {"type": self.typeName(), - "fields": [f.jsonValue() for f in self.fields]} - - @classmethod - def fromJson(cls, json): - return StructType([StructField.fromJson(f) for f in json["fields"]]) - - -class UserDefinedType(DataType): - """ - .. note:: WARN: Spark Internal Use Only - SQL User-Defined Type (UDT). - """ - - @classmethod - def typeName(cls): - return cls.__name__.lower() - - @classmethod - def sqlType(cls): - """ - Underlying SQL storage type for this UDT. - """ - raise NotImplementedError("UDT must implement sqlType().") - - @classmethod - def module(cls): - """ - The Python module of the UDT. - """ - raise NotImplementedError("UDT must implement module().") - - @classmethod - def scalaUDT(cls): - """ - The class name of the paired Scala UDT. - """ - raise NotImplementedError("UDT must have a paired Scala UDT.") - - def serialize(self, obj): - """ - Converts the a user-type object into a SQL datum. - """ - raise NotImplementedError("UDT must implement serialize().") - - def deserialize(self, datum): - """ - Converts a SQL datum into a user-type object. - """ - raise NotImplementedError("UDT must implement deserialize().") - - def json(self): - return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) - - def jsonValue(self): - schema = { - "type": "udt", - "class": self.scalaUDT(), - "pyClass": "%s.%s" % (self.module(), type(self).__name__), - "sqlType": self.sqlType().jsonValue() - } - return schema - - @classmethod - def fromJson(cls, json): - pyUDT = json["pyClass"] - split = pyUDT.rfind(".") - pyModule = pyUDT[:split] - pyClass = pyUDT[split+1:] - m = __import__(pyModule, globals(), locals(), [pyClass], -1) - UDT = getattr(m, pyClass) - return UDT() - - def __eq__(self, other): - return type(self) == type(other) - - -_all_primitive_types = dict((v.typeName(), v) - for v in globals().itervalues() - if type(v) is PrimitiveTypeSingleton and - v.__base__ == PrimitiveType) - - -_all_complex_types = dict((v.typeName(), v) - for v in [ArrayType, MapType, StructType]) - - -def _parse_datatype_json_string(json_string): - """Parses the given data type JSON string. - >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) - ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) - ... return datatype == python_datatype - >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) - True - >>> # Simple ArrayType. - >>> simple_arraytype = ArrayType(StringType(), True) - >>> check_datatype(simple_arraytype) - True - >>> # Simple MapType. - >>> simple_maptype = MapType(StringType(), LongType()) - >>> check_datatype(simple_maptype) - True - >>> # Simple StructType. - >>> simple_structtype = StructType([ - ... StructField("a", DecimalType(), False), - ... StructField("b", BooleanType(), True), - ... StructField("c", LongType(), True), - ... StructField("d", BinaryType(), False)]) - >>> check_datatype(simple_structtype) - True - >>> # Complex StructType. - >>> complex_structtype = StructType([ - ... StructField("simpleArray", simple_arraytype, True), - ... StructField("simpleMap", simple_maptype, True), - ... StructField("simpleStruct", simple_structtype, True), - ... StructField("boolean", BooleanType(), False), - ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) - >>> check_datatype(complex_structtype) - True - >>> # Complex ArrayType. - >>> complex_arraytype = ArrayType(complex_structtype, True) - >>> check_datatype(complex_arraytype) - True - >>> # Complex MapType. - >>> complex_maptype = MapType(complex_structtype, - ... complex_arraytype, False) - >>> check_datatype(complex_maptype) - True - >>> check_datatype(ExamplePointUDT()) - True - >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> check_datatype(structtype_with_udt) - True - """ - return _parse_datatype_json_value(json.loads(json_string)) - - -_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") - - -def _parse_datatype_json_value(json_value): - if type(json_value) is unicode: - if json_value in _all_primitive_types.keys(): - return _all_primitive_types[json_value]() - elif json_value == u'decimal': - return DecimalType() - elif _FIXED_DECIMAL.match(json_value): - m = _FIXED_DECIMAL.match(json_value) - return DecimalType(int(m.group(1)), int(m.group(2))) - else: - raise ValueError("Could not parse datatype: %s" % json_value) - else: - tpe = json_value["type"] - if tpe in _all_complex_types: - return _all_complex_types[tpe].fromJson(json_value) - elif tpe == 'udt': - return UserDefinedType.fromJson(json_value) - else: - raise ValueError("not supported type: %s" % tpe) - - -# Mapping Python types to Spark SQL DataType -_type_mappings = { - type(None): NullType, - bool: BooleanType, - int: IntegerType, - long: LongType, - float: DoubleType, - str: StringType, - unicode: StringType, - bytearray: BinaryType, - decimal.Decimal: DecimalType, - datetime.date: DateType, - datetime.datetime: TimestampType, - datetime.time: TimestampType, -} - - -def _infer_type(obj): - """Infer the DataType from obj - - >>> p = ExamplePoint(1.0, 2.0) - >>> _infer_type(p) - ExamplePointUDT - """ - if obj is None: - raise ValueError("Can not infer type for None") - - if hasattr(obj, '__UDT__'): - return obj.__UDT__ - - dataType = _type_mappings.get(type(obj)) - if dataType is not None: - return dataType() - - if isinstance(obj, dict): - for key, value in obj.iteritems(): - if key is not None and value is not None: - return MapType(_infer_type(key), _infer_type(value), True) - else: - return MapType(NullType(), NullType(), True) - elif isinstance(obj, (list, array)): - for v in obj: - if v is not None: - return ArrayType(_infer_type(obj[0]), True) - else: - return ArrayType(NullType(), True) - else: - try: - return _infer_schema(obj) - except ValueError: - raise ValueError("not supported type: %s" % type(obj)) - - -def _infer_schema(row): - """Infer the schema from dict/namedtuple/object""" - if isinstance(row, dict): - items = sorted(row.items()) - - elif isinstance(row, tuple): - if hasattr(row, "_fields"): # namedtuple - items = zip(row._fields, tuple(row)) - elif hasattr(row, "__FIELDS__"): # Row - items = zip(row.__FIELDS__, tuple(row)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in row): - items = row - else: - raise ValueError("Can't infer schema from tuple") - - elif hasattr(row, "__dict__"): # object - items = sorted(row.__dict__.items()) - - else: - raise ValueError("Can not infer schema for type: %s" % type(row)) - - fields = [StructField(k, _infer_type(v), True) for k, v in items] - return StructType(fields) - - -def _need_python_to_sql_conversion(dataType): - """ - Checks whether we need python to sql conversion for the given type. - For now, only UDTs need this conversion. - - >>> _need_python_to_sql_conversion(DoubleType()) - False - >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), - ... StructField("values", ArrayType(DoubleType(), False), False)]) - >>> _need_python_to_sql_conversion(schema0) - False - >>> _need_python_to_sql_conversion(ExamplePointUDT()) - True - >>> schema1 = ArrayType(ExamplePointUDT(), False) - >>> _need_python_to_sql_conversion(schema1) - True - >>> schema2 = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> _need_python_to_sql_conversion(schema2) - True - """ - if isinstance(dataType, StructType): - return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) - elif isinstance(dataType, ArrayType): - return _need_python_to_sql_conversion(dataType.elementType) - elif isinstance(dataType, MapType): - return _need_python_to_sql_conversion(dataType.keyType) or \ - _need_python_to_sql_conversion(dataType.valueType) - elif isinstance(dataType, UserDefinedType): - return True - else: - return False - - -def _python_to_sql_converter(dataType): - """ - Returns a converter that converts a Python object into a SQL datum for the given type. - - >>> conv = _python_to_sql_converter(DoubleType()) - >>> conv(1.0) - 1.0 - >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) - >>> conv([1.0, 2.0]) - [1.0, 2.0] - >>> conv = _python_to_sql_converter(ExamplePointUDT()) - >>> conv(ExamplePoint(1.0, 2.0)) - [1.0, 2.0] - >>> schema = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> conv = _python_to_sql_converter(schema) - >>> conv((1.0, ExamplePoint(1.0, 2.0))) - (1.0, [1.0, 2.0]) - """ - if not _need_python_to_sql_conversion(dataType): - return lambda x: x - - if isinstance(dataType, StructType): - names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - converters = map(_python_to_sql_converter, types) - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): - return tuple(c(v) for c, v in zip(converters, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs - d = dict(obj) - return tuple(c(d.get(n)) for n, c in zip(names, converters)) - else: - return tuple(c(v) for c, v in zip(converters, obj)) - else: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) - return converter - elif isinstance(dataType, ArrayType): - element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: [element_converter(v) for v in a] - elif isinstance(dataType, MapType): - key_converter = _python_to_sql_converter(dataType.keyType) - value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) - elif isinstance(dataType, UserDefinedType): - return lambda obj: dataType.serialize(obj) - else: - raise ValueError("Unexpected type %r" % dataType) - - -def _has_nulltype(dt): - """ Return whether there is NullType in `dt` or not """ - if isinstance(dt, StructType): - return any(_has_nulltype(f.dataType) for f in dt.fields) - elif isinstance(dt, ArrayType): - return _has_nulltype((dt.elementType)) - elif isinstance(dt, MapType): - return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) - else: - return isinstance(dt, NullType) - - -def _merge_type(a, b): - if isinstance(a, NullType): - return b - elif isinstance(b, NullType): - return a - elif type(a) is not type(b): - # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (a, b)) - - # same type - if isinstance(a, StructType): - nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) - for f in a.fields] - names = set([f.name for f in fields]) - for n in nfs: - if n not in names: - fields.append(StructField(n, nfs[n])) - return StructType(fields) - - elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) - - elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), - True) - else: - return a - - -def _create_converter(dataType): - """Create an converter to drop the names of fields in obj """ - if isinstance(dataType, ArrayType): - conv = _create_converter(dataType.elementType) - return lambda row: map(conv, row) - - elif isinstance(dataType, MapType): - kconv = _create_converter(dataType.keyType) - vconv = _create_converter(dataType.valueType) - return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems()) - - elif isinstance(dataType, NullType): - return lambda x: None - - elif not isinstance(dataType, StructType): - return lambda x: x - - # dataType must be StructType - names = [f.name for f in dataType.fields] - converters = [_create_converter(f.dataType) for f in dataType.fields] - - def convert_struct(obj): - if obj is None: - return - - if isinstance(obj, tuple): - if hasattr(obj, "_fields"): - d = dict(zip(obj._fields, obj)) - elif hasattr(obj, "__FIELDS__"): - d = dict(zip(obj.__FIELDS__, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - d = dict(obj) - else: - raise ValueError("unexpected tuple: %s" % str(obj)) - - elif isinstance(obj, dict): - d = obj - elif hasattr(obj, "__dict__"): # object - d = obj.__dict__ - else: - raise ValueError("Unexpected obj: %s" % obj) - - return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - - return convert_struct - - -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - -def _split_schema_abstract(s): - """ - split the schema abstract into fields - - >>> _split_schema_abstract("a b c") - ['a', 'b', 'c'] - >>> _split_schema_abstract("a(a b)") - ['a(a b)'] - >>> _split_schema_abstract("a b[] c{a b}") - ['a', 'b[]', 'c{a b}'] - >>> _split_schema_abstract(" ") - [] - """ - - r = [] - w = '' - brackets = [] - for c in s: - if c == ' ' and not brackets: - if w: - r.append(w) - w = '' - else: - w += c - if c in _BRACKETS: - brackets.append(c) - elif c in _BRACKETS.values(): - if not brackets or c != _BRACKETS[brackets.pop()]: - raise ValueError("unexpected " + c) - - if brackets: - raise ValueError("brackets not closed: %s" % brackets) - if w: - r.append(w) - return r - - -def _parse_field_abstract(s): - """ - Parse a field in schema abstract - - >>> _parse_field_abstract("a") - StructField(a,None,true) - >>> _parse_field_abstract("b(c d)") - StructField(b,StructType(...c,None,true),StructField(d... - >>> _parse_field_abstract("a[]") - StructField(a,ArrayType(None,true),true) - >>> _parse_field_abstract("a{[]}") - StructField(a,MapType(None,ArrayType(None,true),true),true) - """ - if set(_BRACKETS.keys()) & set(s): - idx = min((s.index(c) for c in _BRACKETS if c in s)) - name = s[:idx] - return StructField(name, _parse_schema_abstract(s[idx:]), True) - else: - return StructField(s, None, True) - - -def _parse_schema_abstract(s): - """ - parse abstract into schema - - >>> _parse_schema_abstract("a b c") - StructType...a...b...c... - >>> _parse_schema_abstract("a[b c] b{}") - StructType...a,ArrayType...b...c...b,MapType... - >>> _parse_schema_abstract("c{} d{a b}") - StructType...c,MapType...d,MapType...a...b... - >>> _parse_schema_abstract("a b(t)").fields[1] - StructField(b,StructType(List(StructField(t,None,true))),true) - """ - s = s.strip() - if not s: - return - - elif s.startswith('('): - return _parse_schema_abstract(s[1:-1]) - - elif s.startswith('['): - return ArrayType(_parse_schema_abstract(s[1:-1]), True) - - elif s.startswith('{'): - return MapType(None, _parse_schema_abstract(s[1:-1])) - - parts = _split_schema_abstract(s) - fields = [_parse_field_abstract(p) for p in parts] - return StructType(fields) - - -def _infer_schema_type(obj, dataType): - """ - Fill the dataType with types inferred from obj - - >>> schema = _parse_schema_abstract("a b c d") - >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) - >>> _infer_schema_type(row, schema) - StructType...IntegerType...DoubleType...StringType...DateType... - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> _infer_schema_type(row, schema) - StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... - """ - if dataType is None: - return _infer_type(obj) - - if not obj: - return NullType() - - if isinstance(dataType, ArrayType): - eType = _infer_schema_type(obj[0], dataType.elementType) - return ArrayType(eType, True) - - elif isinstance(dataType, MapType): - k, v = obj.iteritems().next() - return MapType(_infer_schema_type(k, dataType.keyType), - _infer_schema_type(v, dataType.valueType)) - - elif isinstance(dataType, StructType): - fs = dataType.fields - assert len(fs) == len(obj), \ - "Obj(%s) have different length with fields(%s)" % (obj, fs) - fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) - for o, f in zip(obj, fs)] - return StructType(fields) - - else: - raise ValueError("Unexpected dataType: %s" % dataType) - - -_acceptable_types = { - BooleanType: (bool,), - ByteType: (int, long), - ShortType: (int, long), - IntegerType: (int, long), - LongType: (int, long), - FloatType: (float,), - DoubleType: (float,), - DecimalType: (decimal.Decimal,), - StringType: (str, unicode), - BinaryType: (bytearray,), - DateType: (datetime.date,), - TimestampType: (datetime.datetime,), - ArrayType: (list, tuple, array), - MapType: (dict,), - StructType: (tuple, list), -} - - -def _verify_type(obj, dataType): - """ - Verify the type of obj against dataType, raise an exception if - they do not match. - - >>> _verify_type(None, StructType([])) - >>> _verify_type("", StringType()) - >>> _verify_type(0, IntegerType()) - >>> _verify_type(range(3), ArrayType(ShortType())) - >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - >>> _verify_type({}, MapType(StringType(), IntegerType())) - >>> _verify_type((), StructType([])) - >>> _verify_type([], StructType([])) - >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - """ - # all objects are nullable - if obj is None: - return - - if isinstance(dataType, UserDefinedType): - if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.serialize(obj), dataType.sqlType()) - return - - _type = type(dataType) - assert _type in _acceptable_types, "unkown datatype: %s" % dataType - - # subclass of them can not be deserialized in JVM - if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" - % (dataType, type(obj))) - - if isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType) - - elif isinstance(dataType, MapType): - for k, v in obj.iteritems(): - _verify_type(k, dataType.keyType) - _verify_type(v, dataType.valueType) - - elif isinstance(dataType, StructType): - if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with" - "length of fields (%d)" % (len(obj), len(dataType.fields))) - for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType) - - -_cached_cls = {} - - -def _restore_object(dataType, obj): - """ Restore object during unpickling. """ - # use id(dataType) as key to speed up lookup in dict - # Because of batched pickling, dataType will be the - # same object in most cases. - k = id(dataType) - cls = _cached_cls.get(k) - if cls is None: - # use dataType as key to avoid create multiple class - cls = _cached_cls.get(dataType) - if cls is None: - cls = _create_cls(dataType) - _cached_cls[dataType] = cls - _cached_cls[k] = cls - return cls(obj) - - -def _create_object(cls, v): - """ Create an customized object with class `cls`. """ - # datetime.date would be deserialized as datetime.datetime - # from java type, so we need to set it back. - if cls is datetime.date and isinstance(v, datetime.datetime): - return v.date() - return cls(v) if v is not None else v - - -def _create_getter(dt, i): - """ Create a getter for item `i` with schema """ - cls = _create_cls(dt) - - def getter(self): - return _create_object(cls, self[i]) - - return getter - - -def _has_struct_or_date(dt): - """Return whether `dt` is or has StructType/DateType in it""" - if isinstance(dt, StructType): - return True - elif isinstance(dt, ArrayType): - return _has_struct_or_date(dt.elementType) - elif isinstance(dt, MapType): - return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) - elif isinstance(dt, DateType): - return True - elif isinstance(dt, UserDefinedType): - return True - return False - - -def _create_properties(fields): - """Create properties according to fields""" - ps = {} - for i, f in enumerate(fields): - name = f.name - if (name.startswith("__") and name.endswith("__") - or keyword.iskeyword(name)): - warnings.warn("field name %s can not be accessed in Python," - "use position to access it instead" % name) - if _has_struct_or_date(f.dataType): - # delay creating object until accessing it - getter = _create_getter(f.dataType, i) - else: - getter = itemgetter(i) - ps[name] = property(getter) - return ps - - -def _create_cls(dataType): - """ - Create an class by dataType - - The created class is similar to namedtuple, but can have nested schema. - - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> import pickle - >>> pickle.loads(pickle.dumps(obj)) - Row(a=1, b=1.0, c='str') - - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> pickle.loads(pickle.dumps(obj)) - Row(a=[1], b={'key': Row(c=1, d=2.0)}) - >>> pickle.loads(pickle.dumps(obj.a)) - [1] - >>> pickle.loads(pickle.dumps(obj.b)) - {'key': Row(c=1, d=2.0)} - """ - - if isinstance(dataType, ArrayType): - cls = _create_cls(dataType.elementType) - - def List(l): - if l is None: - return - return [_create_object(cls, v) for v in l] - - return List - - elif isinstance(dataType, MapType): - kcls = _create_cls(dataType.keyType) - vcls = _create_cls(dataType.valueType) - - def Dict(d): - if d is None: - return - return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) - - return Dict - - elif isinstance(dataType, DateType): - return datetime.date - - elif isinstance(dataType, UserDefinedType): - return lambda datum: dataType.deserialize(datum) - - elif not isinstance(dataType, StructType): - # no wrapper for primitive types - return lambda x: x - - class Row(tuple): - - """ Row in DataFrame """ - __DATATYPE__ = dataType - __FIELDS__ = tuple(f.name for f in dataType.fields) - __slots__ = () - - # create property for fast access - locals().update(_create_properties(dataType.fields)) - - def asDict(self): - """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__FIELDS__) - - def __repr__(self): - # call collect __repr__ for nested objects - return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__FIELDS__)) - - def __reduce__(self): - return (_restore_object, (self.__DATATYPE__, tuple(self))) - - return Row - - -class SQLContext(object): - - """Main entry point for Spark SQL functionality. - - A SQLContext can be used create L{DataFrame}, register L{DataFrame} as - tables, execute SQL over tables, cache tables, and read parquet files. - """ - - def __init__(self, sparkContext, sqlContext=None): - """Create a new SQLContext. - - :param sparkContext: The SparkContext to wrap. - :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new - SQLContext in the JVM, instead we make all calls to this object. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - - >>> bad_rdd = sc.parallelize([1,2,3]) - >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - - >>> from datetime import datetime - >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, - ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), - ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> df = sqlCtx.inferSchema(allTypes) - >>> df.registerTempTable("allTypes") - >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' - ... 'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, - ... x.row.a, x.list)).collect() - [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] - """ - self._sc = sparkContext - self._jsc = self._sc._jsc - self._jvm = self._sc._jvm - self._scala_SQLContext = sqlContext - - @property - def _ssql_ctx(self): - """Accessor for the JVM Spark SQL context. - - Subclasses can override this property to provide their own - JVM Contexts. - """ - if self._scala_SQLContext is None: - self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) - return self._scala_SQLContext - - def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. - - >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() - [Row(c0=u'4')] - >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] - """ - func = lambda _, it: imap(lambda x: f(*x), it) - command = (func, None, - AutoBatchedSerializer(PickleSerializer()), - AutoBatchedSerializer(PickleSerializer())) - ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - broadcast = self._sc.broadcast(pickled_command) - pickled_command = ser.dumps(broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self._sc._pickled_broadcast_vars], - self._sc._gateway._gateway_client) - self._sc._pickled_broadcast_vars.clear() - env = MapConverter().convert(self._sc.environment, - self._sc._gateway._gateway_client) - includes = ListConverter().convert(self._sc._python_includes, - self._sc._gateway._gateway_client) - self._ssql_ctx.udf().registerPython(name, - bytearray(pickled_command), - env, - includes, - self._sc.pythonExec, - broadcast_vars, - self._sc._javaAccumulator, - returnType.json()) - - def inferSchema(self, rdd, samplingRatio=None): - """Infer and apply a schema to an RDD of L{Row}. - - When samplingRatio is specified, the schema is inferred by looking - at the types of each row in the sampled dataset. Otherwise, the - first 100 rows of the RDD are inspected. Nested collections are - supported, which can include array, dict, list, Row, tuple, - namedtuple, or object. - - Each row could be L{pyspark.sql.Row} object or namedtuple or objects. - Using top level dicts is deprecated, as dict is used to represent Maps. - - If a single column has multiple distinct inferred types, it may cause - runtime exceptions. - - >>> rdd = sc.parallelize( - ... [Row(field1=1, field2="row1"), - ... Row(field1=2, field2="row2"), - ... Row(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') - - >>> NestedRow = Row("f1", "f2") - >>> nestedRdd1 = sc.parallelize([ - ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), - ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> df = sqlCtx.inferSchema(nestedRdd1) - >>> df.collect() - [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] - - >>> nestedRdd2 = sc.parallelize([ - ... NestedRow([[1, 2], [2, 3]], [1, 2]), - ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> df = sqlCtx.inferSchema(nestedRdd2) - >>> df.collect() - [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] - - >>> from collections import namedtuple - >>> CustomRow = namedtuple('CustomRow', 'field1 field2') - >>> rdd = sc.parallelize( - ... [CustomRow(field1=1, field2="row1"), - ... CustomRow(field1=2, field2="row2"), - ... CustomRow(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') - """ - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - first = rdd.first() - if not first: - raise ValueError("The first row in RDD is empty, " - "can not infer schema") - if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") - - if samplingRatio is None: - schema = _infer_schema(first) - if _has_nulltype(schema): - for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) - if not _has_nulltype(schema): - break - else: - warnings.warn("Some of types cannot be determined by the " - "first 100 rows, please try again with sampling") - else: - if samplingRatio > 0.99: - rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) - - converter = _create_converter(schema) - rdd = rdd.map(converter) - return self.applySchema(rdd, schema) - - def applySchema(self, rdd, schema): - """ - Applies the given schema to the given RDD of L{tuple} or L{list}. - - These tuples or lists can contain complex nested structures like - lists, maps or nested rows. - - The schema should be a StructType. - - It is important that the schema matches the types of the objects - in each row or exceptions could be thrown at runtime. - - >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) - >>> schema = StructType([StructField("field1", IntegerType(), False), - ... StructField("field2", StringType(), False)]) - >>> df = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT * from table1") - >>> df2.collect() - [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - - >>> from datetime import date, datetime - >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, - ... date(2010, 1, 1), - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3], None)]) - >>> schema = StructType([ - ... StructField("byte1", ByteType(), False), - ... StructField("byte2", ByteType(), False), - ... StructField("short1", ShortType(), False), - ... StructField("short2", ShortType(), False), - ... StructField("int", IntegerType(), False), - ... StructField("float", FloatType(), False), - ... StructField("date", DateType(), False), - ... StructField("time", TimestampType(), False), - ... StructField("map", - ... MapType(StringType(), IntegerType(), False), False), - ... StructField("struct", - ... StructType([StructField("b", ShortType(), False)]), False), - ... StructField("list", ArrayType(ByteType(), False), False), - ... StructField("null", DoubleType(), True)]) - >>> df = sqlCtx.applySchema(rdd, schema) - >>> results = df.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, - ... x.time, x.map["a"], x.struct.b, x.list, x.null)) - >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE - (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), - datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - - >>> df.registerTempTable("table2") - >>> sqlCtx.sql( - ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.5 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] - - >>> rdd = sc.parallelize([(127, -32768, 1.0, - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3])]) - >>> abstract = "byte short float time map{} struct(b) list[]" - >>> schema = _parse_schema_abstract(abstract) - >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> df = sqlCtx.applySchema(rdd, typedSchema) - >>> df.collect() - [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] - """ - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) - rows = rdd.take(10) - - for row in rows: - _verify_type(row, schema) - - # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) - - jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) - - def registerRDDAsTable(self, rdd, tableName): - """Registers the given RDD as a temporary table in the catalog. - - Temporary tables exist only during the lifetime of this instance of - SQLContext. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - """ - if (rdd.__class__ is DataFrame): - df = rdd._jdf - self._ssql_ctx.registerRDDAsTable(df, tableName) - else: - raise ValueError("Can only register DataFrame as table") - - def parquetFile(self, path): - """Loads a Parquet file, returning the result as a L{DataFrame}. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - jdf = self._ssql_ctx.parquetFile(path) - return DataFrame(jdf, self) - - def jsonFile(self, path, schema=None, samplingRatio=1.0): - """ - Loads a text file storing one JSON object per line as a - L{DataFrame}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> ofn = open(jsonFile, 'w') - >>> for json in jsonStrings: - ... print>>ofn, json - >>> ofn.close() - >>> df1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in df2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) - >>> sqlCtx.registerRDDAsTable(df3, "table2") - >>> df4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in df4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> df5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerRDDAsTable(df5, "table3") - >>> df6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> df6.collect() - [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] - """ - if schema is None: - df = self._ssql_ctx.jsonFile(path, samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonFile(path, scala_datatype) - return DataFrame(df, self) - - def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{DataFrame}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> df1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in df2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) - >>> sqlCtx.registerRDDAsTable(df3, "table2") - >>> df4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in df4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> df5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerRDDAsTable(df5, "table3") - >>> df6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> df6.collect() - [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] - - >>> sqlCtx.jsonRDD(sc.parallelize(['{}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - """ - - def func(iterator): - for x in iterator: - if not isinstance(x, basestring): - x = unicode(x) - if isinstance(x, unicode): - x = x.encode("utf-8") - yield x - keyed = rdd.mapPartitions(func) - keyed._bypass_serializer = True - jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - if schema is None: - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return DataFrame(df, self) - - def sql(self, sqlQuery): - """Return a L{DataFrame} representing the result of the given query. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> df2.collect() - [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] - """ - return DataFrame(self._ssql_ctx.sql(sqlQuery), self) - - def table(self, tableName): - """Returns the specified table as a L{DataFrame}. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.table("table1") - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - return DataFrame(self._ssql_ctx.table(tableName), self) - - def cacheTable(self, tableName): - """Caches the specified table in-memory.""" - self._ssql_ctx.cacheTable(tableName) - - def uncacheTable(self, tableName): - """Removes the specified table from the in-memory cache.""" - self._ssql_ctx.uncacheTable(tableName) - - -class HiveContext(SQLContext): - - """A variant of Spark SQL that integrates with data stored in Hive. - - Configuration for Hive is read from hive-site.xml on the classpath. - It supports running both SQL and HiveQL commands. - """ - - def __init__(self, sparkContext, hiveContext=None): - """Create a new HiveContext. - - :param sparkContext: The SparkContext to wrap. - :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new - HiveContext in the JVM, instead we make all calls to this object. - """ - SQLContext.__init__(self, sparkContext) - - if hiveContext: - self._scala_HiveContext = hiveContext - - @property - def _ssql_ctx(self): - try: - if not hasattr(self, '_scala_HiveContext'): - self._scala_HiveContext = self._get_hive_ctx() - return self._scala_HiveContext - except Py4JError as e: - raise Exception("You must build Spark with Hive. " - "Export 'SPARK_HIVE=true' and run " - "build/sbt assembly", e) - - def _get_hive_ctx(self): - return self._jvm.HiveContext(self._jsc.sc()) - - -class LocalHiveContext(HiveContext): - - def __init__(self, sparkContext, sqlContext=None): - HiveContext.__init__(self, sparkContext, sqlContext) - warnings.warn("LocalHiveContext is deprecated. " - "Use HiveContext instead.", DeprecationWarning) - - def _get_hive_ctx(self): - return self._jvm.LocalHiveContext(self._jsc.sc()) - - -def _create_row(fields, values): - row = Row(*values) - row.__FIELDS__ = fields - return row - - -class Row(tuple): - - """ - A row in L{DataFrame}. The fields in it can be accessed like attributes. - - Row can be used to create a row object by using named arguments, - the fields will be sorted by names. - - >>> row = Row(name="Alice", age=11) - >>> row - Row(age=11, name='Alice') - >>> row.name, row.age - ('Alice', 11) - - Row also can be used to create another Row like class, then it - could be used to create Row objects, such as - - >>> Person = Row("name", "age") - >>> Person - - >>> Person("Alice", 11) - Row(name='Alice', age=11) - """ - - def __new__(self, *args, **kwargs): - if args and kwargs: - raise ValueError("Can not use both args " - "and kwargs to create Row") - if args: - # create row class or objects - return tuple.__new__(self, args) - - elif kwargs: - # create row objects - names = sorted(kwargs.keys()) - values = tuple(kwargs[n] for n in names) - row = tuple.__new__(self, values) - row.__FIELDS__ = names - return row - - else: - raise ValueError("No args or kwargs") - - def asDict(self): - """ - Return as an dict - """ - if not hasattr(self, "__FIELDS__"): - raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__FIELDS__, self)) - - # let obect acs like class - def __call__(self, *args): - """create new Row object""" - return _create_row(self, args) - - def __getattr__(self, item): - if item.startswith("__"): - raise AttributeError(item) - try: - # it will be slow when it has many fields, - # but this will not be used in normal cases - idx = self.__FIELDS__.index(item) - return self[idx] - except IndexError: - raise AttributeError(item) - - def __reduce__(self): - if hasattr(self, "__FIELDS__"): - return (_create_row, (self.__FIELDS__, tuple(self))) - else: - return tuple.__reduce__(self) - - def __repr__(self): - if hasattr(self, "__FIELDS__"): - return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) - else: - return "" % ", ".join(self) - - -class DataFrame(object): - - """A collection of rows that have the same columns. - - A :class:`DataFrame` is equivalent to a relational table in Spark SQL, - and can be created using various functions in :class:`SQLContext`:: - - people = sqlContext.parquetFile("...") - - Once created, it can be manipulated using the various domain-specific-language - (DSL) functions defined in: [[DataFrame]], [[Column]]. - - To select a column from the data frame, use the apply method:: - - ageCol = people.age - - Note that the :class:`Column` type can also be manipulated - through its various functions:: - - # The following creates a new column that increases everybody's age by 10. - people.age + 10 - - - A more concrete example:: - - # To create DataFrame using SQLContext - people = sqlContext.parquetFile("...") - department = sqlContext.parquetFile("...") - - people.filter(people.age > 30).join(department, people.deptId == department.id)) \ - .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - self._sc = sql_ctx and sql_ctx._sc - self.is_cached = False - - @property - def rdd(self): - """Return the content of the :class:`DataFrame` as an :class:`RDD` - of :class:`Row`s. """ - if not hasattr(self, '_lazy_rdd'): - jrdd = self._jdf.javaToPython() - rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) - schema = self.schema() - - def applySchema(it): - cls = _create_cls(schema) - return itertools.imap(cls, it) - - self._lazy_rdd = rdd.mapPartitions(applySchema) - - return self._lazy_rdd - - def limit(self, num): - """Limit the result count to the number specified. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.limit(2).collect() - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - >>> df.limit(0).collect() - [] - """ - jdf = self._jdf.limit(num) - return DataFrame(jdf, self.sql_ctx) - - def toJSON(self, use_unicode=False): - """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. - - >>> df1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( "SELECT * from table1") - >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' - True - >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1") - >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] - True - """ - rdd = self._jdf.toJSON() - return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - - def saveAsParquetFile(self, path): - """Save the contents as a Parquet file, preserving the schema. - - Files that are written out using this method can be read back in as - a DataFrame using the L{SQLContext.parquetFile} method. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(df2.collect()) == sorted(df.collect()) - True - """ - self._jdf.saveAsParquetFile(path) - - def registerTempTable(self, name): - """Registers this RDD as a temporary table using the given name. - - The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this DataFrame. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.registerTempTable("test") - >>> df2 = sqlCtx.sql("select * from test") - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - self._jdf.registerTempTable(name) - - def registerAsTable(self, name): - """DEPRECATED: use registerTempTable() instead""" - warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) - self.registerTempTable(name) - - def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this DataFrame into the specified table. - - Optionally overwriting any existing data. - """ - self._jdf.insertInto(tableName, overwrite) - - def saveAsTable(self, tableName): - """Creates a new table with the contents of this DataFrame.""" - self._jdf.saveAsTable(tableName) - - def schema(self): - """Returns the schema of this DataFrame (represented by - a L{StructType}).""" - return _parse_datatype_json_string(self._jdf.schema().json()) - - def printSchema(self): - """Prints out the schema in the tree format.""" - print (self._jdf.schema().treeString()) - - def count(self): - """Return the number of elements in this RDD. - - Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the DataFrame, - which supports features such as filter pushdown. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.count() - 3L - >>> df.count() == df.map(lambda x: x).count() - True - """ - return self._jdf.count() - - def collect(self): - """Return a list that contains all of the rows. - - Each object in the list is a Row, the fields can be accessed as - attributes. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect() - [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] - """ - with SCCallSiteSync(self._sc) as css: - bytesInJava = self._jdf.javaToPython().collect().iterator() - cls = _create_cls(self.schema()) - tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) - tempFile.close() - self._sc._writeToFile(bytesInJava, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) - os.unlink(tempFile.name) - return [cls(r) for r in rs] - - def take(self, num): - """Take the first num rows of the RDD. - - Each object in the list is a Row, the fields can be accessed as - attributes. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.take(2) - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - """ - return self.limit(num).collect() - - def map(self, f): - """ Return a new RDD by applying a function to each Row, it's a - shorthand for df.rdd.map() - """ - return self.rdd.map(f) - - def mapPartitions(self, f, preservesPartitioning=False): - """ - Return a new RDD by applying a function to each partition. - - >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(iterator): yield 1 - >>> rdd.mapPartitions(f).sum() - 4 - """ - return self.rdd.mapPartitions(f, preservesPartitioning) - - def cache(self): - """ Persist with the default storage level (C{MEMORY_ONLY_SER}). - """ - self.is_cached = True - self._jdf.cache() - return self - - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): - """ Set the storage level to persist its values across operations - after the first time it is computed. This can only be used to assign - a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). - """ - self.is_cached = True - javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) - self._jdf.persist(javaStorageLevel) - return self - - def unpersist(self, blocking=True): - """ Mark it as non-persistent, and remove all blocks for it from - memory and disk. - """ - self.is_cached = False - self._jdf.unpersist(blocking) - return self - - # def coalesce(self, numPartitions, shuffle=False): - # rdd = self._jdf.coalesce(numPartitions, shuffle, None) - # return DataFrame(rdd, self.sql_ctx) - - def repartition(self, numPartitions): - """ Return a new :class:`DataFrame` that has exactly `numPartitions` - partitions. - """ - rdd = self._jdf.repartition(numPartitions, None) - return DataFrame(rdd, self.sql_ctx) - - def sample(self, withReplacement, fraction, seed=None): - """ - Return a sampled subset of this DataFrame. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.sample(False, 0.5, 97).count() - 2L - """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jdf.sample(withReplacement, fraction, long(seed)) - return DataFrame(rdd, self.sql_ctx) - - # def takeSample(self, withReplacement, num, seed=None): - # """Return a fixed-size sampled subset of this DataFrame. - # - # >>> df = sqlCtx.inferSchema(rdd) - # >>> df.takeSample(False, 2, 97) - # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] - # """ - # seed = seed if seed is not None else random.randint(0, sys.maxint) - # with SCCallSiteSync(self.context) as css: - # bytesInJava = self._jdf \ - # .takeSampleToPython(withReplacement, num, long(seed)) \ - # .iterator() - # cls = _create_cls(self.schema()) - # return map(cls, self._collect_iterator_through_file(bytesInJava)) - - @property - def dtypes(self): - """Return all column names and their data types as a list. - """ - return [(f.name, str(f.dataType)) for f in self.schema().fields] - - @property - def columns(self): - """ Return all column names as a list. - """ - return [f.name for f in self.schema().fields] - - def show(self): - raise NotImplemented - - def join(self, other, joinExprs=None, joinType=None): - """ - Join with another DataFrame, using the given join expression. - The following performs a full outer join between `df1` and `df2`:: - - df1.join(df2, df1.key == df2.key, "outer") - - :param other: Right side of the join - :param joinExprs: Join expression - :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, - `semijoin`. - """ - if joinType is None: - if joinExprs is None: - jdf = self._jdf.join(other._jdf) - else: - jdf = self._jdf.join(other._jdf, joinExprs) - else: - jdf = self._jdf.join(other._jdf, joinExprs, joinType) - return DataFrame(jdf, self.sql_ctx) - - def sort(self, *cols): - """ Return a new [[DataFrame]] sorted by the specified column, - in ascending column. - - :param cols: The columns or expressions used for sorting - """ - if not cols: - raise ValueError("should sort by at least one column") - for i, c in enumerate(cols): - if isinstance(c, basestring): - cols[i] = Column(c) - jcols = [c._jc for c in cols] - jdf = self._jdf.join(*jcols) - return DataFrame(jdf, self.sql_ctx) - - sortBy = sort - - def head(self, n=None): - """ Return the first `n` rows or the first row if n is None. """ - if n is None: - rs = self.head(1) - return rs[0] if rs else None - return self.take(n) - - def first(self): - """ Return the first row. """ - return self.head() - - def tail(self): - raise NotImplemented - - def __getitem__(self, item): - if isinstance(item, basestring): - return Column(self._jdf.apply(item)) - - # TODO projection - raise IndexError - - def __getattr__(self, name): - """ Return the column by given name """ - if name.startswith("__"): - raise AttributeError(name) - return Column(self._jdf.apply(name)) - - def alias(self, name): - """ Alias the current DataFrame """ - return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx) - - def select(self, *cols): - """ Selecting a set of expressions.:: - - df.select() - df.select('colA', 'colB') - df.select(df.colA, df.colB + 1) - - """ - if not cols: - cols = ["*"] - if isinstance(cols[0], basestring): - cols = [_create_column_from_name(n) for n in cols] - else: - cols = [c._jc for c in cols] - jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) - jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) - return DataFrame(jdf, self.sql_ctx) - - def filter(self, condition): - """ Filtering rows using the given condition:: - - df.filter(df.age > 15) - df.where(df.age > 15) - - """ - return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) - - where = filter - - def groupBy(self, *cols): - """ Group the [[DataFrame]] using the specified columns, - so we can run aggregation on them. See :class:`GroupedDataFrame` - for all the available aggregate functions:: - - df.groupBy(df.department).avg() - df.groupBy("department", "gender").agg({ - "salary": "avg", - "age": "max", - }) - """ - if cols and isinstance(cols[0], basestring): - cols = [_create_column_from_name(n) for n in cols] - else: - cols = [c._jc for c in cols] - jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) - jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) - return GroupedDataFrame(jdf, self.sql_ctx) - - def agg(self, *exprs): - """ Aggregate on the entire [[DataFrame]] without groups - (shorthand for df.groupBy.agg()):: - - df.agg({"age": "max", "salary": "avg"}) - """ - return self.groupBy().agg(*exprs) - - def unionAll(self, other): - """ Return a new DataFrame containing union of rows in this - frame and another frame. - - This is equivalent to `UNION ALL` in SQL. - """ - return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) - - def intersect(self, other): - """ Return a new [[DataFrame]] containing rows only in - both this frame and another frame. - - This is equivalent to `INTERSECT` in SQL. - """ - return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) - - def subtract(self, other): - """ Return a new [[DataFrame]] containing rows in this frame - but not in another frame. - - This is equivalent to `EXCEPT` in SQL. - """ - return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) - - def sample(self, withReplacement, fraction, seed=None): - """ Return a new DataFrame by sampling a fraction of rows. """ - if seed is None: - jdf = self._jdf.sample(withReplacement, fraction) - else: - jdf = self._jdf.sample(withReplacement, fraction, seed) - return DataFrame(jdf, self.sql_ctx) - - def addColumn(self, colName, col): - """ Return a new [[DataFrame]] by adding a column. """ - return self.select('*', col.alias(colName)) - - def removeColumn(self, colName): - raise NotImplemented - - -# Having SchemaRDD for backward compatibility (for docs) -class SchemaRDD(DataFrame): - """ - SchemaRDD is deprecated, please use DataFrame - """ - - -def dfapi(f): - def _api(self): - name = f.__name__ - jdf = getattr(self._jdf, name)() - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -class GroupedDataFrame(object): - - """ - A set of methods for aggregations on a :class:`DataFrame`, - created by DataFrame.groupBy(). - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - - def agg(self, *exprs): - """ Compute aggregates by specifying a map from column name - to aggregate methods. - - The available aggregate methods are `avg`, `max`, `min`, - `sum`, `count`. - - :param exprs: list or aggregate columns or a map from column - name to agregate methods. - """ - assert exprs, "exprs should not be empty" - if len(exprs) == 1 and isinstance(exprs[0], dict): - jmap = MapConverter().convert(exprs[0], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(jmap) - else: - # Columns - assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jcols = ListConverter().convert([c._jc for c in exprs[1:]], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) - return DataFrame(jdf, self.sql_ctx) - - @dfapi - def count(self): - """ Count the number of rows for each group. """ - - @dfapi - def mean(self): - """Compute the average value for each numeric columns - for each group. This is an alias for `avg`.""" - - @dfapi - def avg(self): - """Compute the average value for each numeric columns - for each group.""" - - @dfapi - def max(self): - """Compute the max value for each numeric columns for - each group. """ - - @dfapi - def min(self): - """Compute the min value for each numeric column for - each group.""" - - @dfapi - def sum(self): - """Compute the sum for each numeric columns for each - group.""" - - -SCALA_METHOD_MAPPINGS = { - '=': '$eq', - '>': '$greater', - '<': '$less', - '+': '$plus', - '-': '$minus', - '*': '$times', - '/': '$div', - '!': '$bang', - '@': '$at', - '#': '$hash', - '%': '$percent', - '^': '$up', - '&': '$amp', - '~': '$tilde', - '?': '$qmark', - '|': '$bar', - '\\': '$bslash', - ':': '$colon', -} - - -def _create_column_from_literal(literal): - sc = SparkContext._active_spark_context - return sc._jvm.org.apache.spark.sql.Dsl.lit(literal) - - -def _create_column_from_name(name): - sc = SparkContext._active_spark_context - return sc._jvm.IncomputableColumn(name) - - -def _scalaMethod(name): - """ Translate operators into methodName in Scala - - For example: - >>> _scalaMethod('+') - '$plus' - >>> _scalaMethod('>=') - '$greater$eq' - >>> _scalaMethod('cast') - 'cast' - """ - return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name) - - -def _unary_op(name): - """ Create a method for given unary operator """ - def _(self): - return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx) - return _ - - -def _bin_op(name, pass_literal_through=True): - """ Create a method for given binary operator - - Keyword arguments: - pass_literal_through -- whether to pass literal value directly through to the JVM. - """ - def _(self, other): - if isinstance(other, Column): - jc = other._jc - else: - if pass_literal_through: - jc = other - else: - jc = _create_column_from_literal(other) - return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx) - return _ - - -def _reverse_op(name): - """ Create a method for binary operator (this object is on right side) - """ - def _(self, other): - return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc), - self._jdf, self.sql_ctx) - return _ - - -class Column(DataFrame): - - """ - A column in a DataFrame. - - `Column` instances can be created by: - {{{ - // 1. Select a column out of a DataFrame - df.colName - df["colName"] - - // 2. Create from an expression - df["colName"] + 1 - }}} - """ - - def __init__(self, jc, jdf=None, sql_ctx=None): - self._jc = jc - super(Column, self).__init__(jdf, sql_ctx) - - # arithmetic operators - __neg__ = _unary_op("unary_-") - __add__ = _bin_op("+") - __sub__ = _bin_op("-") - __mul__ = _bin_op("*") - __div__ = _bin_op("/") - __mod__ = _bin_op("%") - __radd__ = _bin_op("+") - __rsub__ = _reverse_op("-") - __rmul__ = _bin_op("*") - __rdiv__ = _reverse_op("/") - __rmod__ = _reverse_op("%") - __abs__ = _unary_op("abs") - abs = _unary_op("abs") - sqrt = _unary_op("sqrt") - - # logistic operators - __eq__ = _bin_op("===") - __ne__ = _bin_op("!==") - __lt__ = _bin_op("<") - __le__ = _bin_op("<=") - __ge__ = _bin_op(">=") - __gt__ = _bin_op(">") - # `and`, `or`, `not` cannot be overloaded in Python - And = _bin_op('&&') - Or = _bin_op('||') - Not = _unary_op('unary_!') - - # bitwise operators - __and__ = _bin_op("&") - __or__ = _bin_op("|") - __invert__ = _unary_op("unary_~") - __xor__ = _bin_op("^") - # __lshift__ = _bin_op("<<") - # __rshift__ = _bin_op(">>") - __rand__ = _bin_op("&") - __ror__ = _bin_op("|") - __rxor__ = _bin_op("^") - # __rlshift__ = _reverse_op("<<") - # __rrshift__ = _reverse_op(">>") - - # container operators - __contains__ = _bin_op("contains") - __getitem__ = _bin_op("getItem") - # __getattr__ = _bin_op("getField") - - # string methods - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") - upper = _unary_op("upper") - lower = _unary_op("lower") - - def substr(self, startPos, pos): - if type(startPos) != type(pos): - raise TypeError("Can not mix the type") - if isinstance(startPos, (int, long)): - jc = self._jc.substr(startPos, pos) - elif isinstance(startPos, Column): - jc = self._jc.substr(startPos._jc, pos._jc) - else: - raise TypeError("Unexpected type: %s" % type(startPos)) - return Column(jc, self._jdf, self.sql_ctx) - - __getslice__ = substr - - # order - asc = _unary_op("asc") - desc = _unary_op("desc") - - isNull = _unary_op("isNull") - isNotNull = _unary_op("isNotNull") - - # `as` is keyword - def alias(self, alias): - return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx) - - def cast(self, dataType): - if self.sql_ctx is None: - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - else: - ssql_ctx = self.sql_ctx._ssql_ctx - jdt = ssql_ctx.parseDataType(dataType.json()) - return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx) - - -def _to_java_column(col): - if isinstance(col, Column): - jcol = col._jc - else: - jcol = _create_column_from_name(col) - return jcol - - -def _aggregate_func(name): - """ Create a function for aggregator by name""" - def _(col): - sc = SparkContext._active_spark_context - jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col)) - return Column(jc) - - return staticmethod(_) - - -class Aggregator(object): - """ - A collections of builtin aggregators - """ - AGGS = [ - 'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs', - 'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct', - ] - for _name in AGGS: - locals()[_name] = _aggregate_func(_name) - del _name - - @staticmethod - def countDistinct(col, *cols): - sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - sc._gateway._gateway_client) - jc = sc._jvm.Dsl.countDistinct(_to_java_column(col), - sc._jvm.Dsl.toColumns(jcols)) - return Column(jc) - - @staticmethod - def approxCountDistinct(col, rsd=None): - sc = SparkContext._active_spark_context - if rsd is None: - jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col)) - else: - jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd) - return Column(jc) - - -def _test(): - import doctest - from pyspark.context import SparkContext - # let doctest run in pyspark.sql, so DataTypes can be picklable - import pyspark.sql - from pyspark.sql import Row, SQLContext - from pyspark.tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') - globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) - globs['rdd'] = sc.parallelize( - [Row(field1=1, field2="row1"), - Row(field1=2, field2="row2"), - Row(field1=3, field2="row3")] - ) - globs['ExamplePoint'] = ExamplePoint - globs['ExamplePointUDT'] = ExamplePointUDT - jsonStrings = [ - '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' - ] - globs['jsonStrings'] = jsonStrings - globs['json'] = sc.parallelize(jsonStrings) - (failure_count, test_count) = doctest.testmod( - pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() - if failure_count: - exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py new file mode 100644 index 0000000000000..65abb24eed823 --- /dev/null +++ b/python/pyspark/sql/__init__.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Important classes of Spark SQL and DataFrames: + + - L{SQLContext} + Main entry point for :class:`DataFrame` and SQL functionality. + - L{DataFrame} + A distributed collection of data grouped into named columns. + - L{Column} + A column expression in a :class:`DataFrame`. + - L{Row} + A row of data in a :class:`DataFrame`. + - L{HiveContext} + Main entry point for accessing data stored in Apache Hive. + - L{GroupedData} + Aggregation methods, returned by :func:`DataFrame.groupBy`. + - L{DataFrameNaFunctions} + Methods for handling missing data (null values). + - L{functions} + List of built-in functions available for :class:`DataFrame`. + - L{types} + List of data types available. +""" + +from pyspark.sql.context import SQLContext, HiveContext +from pyspark.sql.types import Row +from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions + +__all__ = [ + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions' +] diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py new file mode 100644 index 0000000000000..93e2d176a5b6f --- /dev/null +++ b/python/pyspark/sql/context.py @@ -0,0 +1,623 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +import json +from itertools import imap + +from py4j.protocol import Py4JError +from py4j.java_collections import MapConverter + +from pyspark.rdd import RDD, _prepare_for_python_RDD +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ + _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter +from pyspark.sql.dataframe import DataFrame + +try: + import pandas + has_pandas = True +except ImportError: + has_pandas = False + +__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] + + +def _monkey_patch_RDD(sqlContext): + def toDF(self, schema=None, sampleRatio=None): + """ + Converts current :class:`RDD` into a :class:`DataFrame` + + This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)`` + + :param schema: a StructType or list of names of columns + :param samplingRatio: the sample ratio of rows used for inferring + :return: a DataFrame + + >>> rdd.toDF().collect() + [Row(name=u'Alice', age=1)] + """ + return sqlContext.createDataFrame(self, schema, sampleRatio) + + RDD.toDF = toDF + + +class SQLContext(object): + """Main entry point for Spark SQL functionality. + + A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as + tables, execute SQL over tables, cache tables, and read parquet files. + + When created, :class:`SQLContext` adds a method called ``toDF`` to :class:`RDD`, + which could be used to convert an RDD into a DataFrame, it's a shorthand for + :func:`SQLContext.createDataFrame`. + + :param sparkContext: The :class:`SparkContext` backing this SQLContext. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new + SQLContext in the JVM, instead we make all calls to this object. + """ + + def __init__(self, sparkContext, sqlContext=None): + """Creates a new SQLContext. + + >>> from datetime import datetime + >>> sqlContext = SQLContext(sc) + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, + ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), + ... time=datetime(2014, 8, 1, 14, 1, 5))]) + >>> df = allTypes.toDF() + >>> df.registerTempTable("allTypes") + >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + ... 'from allTypes where b and i > 0').collect() + [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + ... x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + """ + self._sc = sparkContext + self._jsc = self._sc._jsc + self._jvm = self._sc._jvm + self._scala_SQLContext = sqlContext + _monkey_patch_RDD(self) + + @property + def _ssql_ctx(self): + """Accessor for the JVM Spark SQL context. + + Subclasses can override this property to provide their own + JVM Contexts. + """ + if self._scala_SQLContext is None: + self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) + return self._scala_SQLContext + + def setConf(self, key, value): + """Sets the given Spark SQL configuration property. + """ + self._ssql_ctx.setConf(key, value) + + def getConf(self, key, defaultValue): + """Returns the value of Spark SQL configuration property for the given key. + + If the key is not set, returns defaultValue. + """ + return self._ssql_ctx.getConf(key, defaultValue) + + @property + def udf(self): + """Returns a :class:`UDFRegistration` for UDF registration.""" + return UDFRegistration(self) + + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + :param name: name of the UDF + :param samplingRatio: lambda function + :param returnType: a :class:`DataType` object + + >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlContext.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + + >>> from pyspark.sql.types import IntegerType + >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + + >>> from pyspark.sql.types import IntegerType + >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + """ + func = lambda _, it: imap(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self) + self._ssql_ctx.udf().registerPython(name, + bytearray(pickled_cmd), + env, + includes, + self._sc.pythonExec, + bvars, + self._sc._javaAccumulator, + returnType.json()) + + def _inferSchema(self, rdd, samplingRatio=None): + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, " + "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated," + "please use pyspark.sql.Row instead") + + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio < 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + return schema + + def inferSchema(self, rdd, samplingRatio=None): + """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. + """ + warnings.warn("inferSchema is deprecated, please use createDataFrame instead") + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + return self.createDataFrame(rdd, None, samplingRatio) + + def applySchema(self, rdd, schema): + """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. + """ + warnings.warn("applySchema is deprecated, please use createDataFrame instead") + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType, but got %s" % schema) + + return self.createDataFrame(rdd, schema) + + def createDataFrame(self, data, schema=None, samplingRatio=None): + """ + Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`, + list or :class:`pandas.DataFrame`. + + When ``schema`` is a list of column names, the type of each column + will be inferred from ``data``. + + When ``schema`` is ``None``, it will try to infer the schema (column names and types) + from ``data``, which should be an RDD of :class:`Row`, + or :class:`namedtuple`, or :class:`dict`. + + If schema inference is needed, ``samplingRatio`` is used to determined the ratio of + rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. + + :param data: an RDD of :class:`Row`/:class:`tuple`/:class:`list`/:class:`dict`, + :class:`list`, or :class:`pandas.DataFrame`. + :param schema: a :class:`StructType` or list of column names. default None. + :param samplingRatio: the sample ratio of rows used for inferring + + >>> l = [('Alice', 1)] + >>> sqlContext.createDataFrame(l).collect() + [Row(_1=u'Alice', _2=1)] + >>> sqlContext.createDataFrame(l, ['name', 'age']).collect() + [Row(name=u'Alice', age=1)] + + >>> d = [{'name': 'Alice', 'age': 1}] + >>> sqlContext.createDataFrame(d).collect() + [Row(age=1, name=u'Alice')] + + >>> rdd = sc.parallelize(l) + >>> sqlContext.createDataFrame(rdd).collect() + [Row(_1=u'Alice', _2=1)] + >>> df = sqlContext.createDataFrame(rdd, ['name', 'age']) + >>> df.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql import Row + >>> Person = Row('name', 'age') + >>> person = rdd.map(lambda r: Person(*r)) + >>> df2 = sqlContext.createDataFrame(person) + >>> df2.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("name", StringType(), True), + ... StructField("age", IntegerType(), True)]) + >>> df3 = sqlContext.createDataFrame(rdd, schema) + >>> df3.collect() + [Row(name=u'Alice', age=1)] + + >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP + [Row(name=u'Alice', age=1)] + """ + if isinstance(data, DataFrame): + raise TypeError("data is already a DataFrame") + + if has_pandas and isinstance(data, pandas.DataFrame): + if schema is None: + schema = list(data.columns) + data = [r.tolist() for r in data.to_records(index=False)] + + if not isinstance(data, RDD): + try: + # data could be list, tuple, generator ... + rdd = self._sc.parallelize(data) + except Exception: + raise ValueError("cannot create an RDD from type: %s" % type(data)) + else: + rdd = data + + if schema is None: + schema = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(schema) + rdd = rdd.map(converter) + + if isinstance(schema, (list, tuple)): + first = rdd.first() + if not isinstance(first, (list, tuple)): + raise ValueError("each row in `rdd` should be list or tuple, " + "but got %r" % type(first)) + row_cls = Row(*schema) + schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) + + # take the first few rows to verify schema + rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + + for row in rows: + _verify_type(row, schema) + + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) + + def registerDataFrameAsTable(self, df, tableName): + """Registers the given :class:`DataFrame` as a temporary table in the catalog. + + Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + """ + if (df.__class__ is DataFrame): + self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName) + else: + raise ValueError("Can only register DataFrame as table") + + def parquetFile(self, *paths): + """Loads a Parquet file, returning the result as a :class:`DataFrame`. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + gateway = self._sc._gateway + jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) + for i in range(0, len(paths)): + jpaths[i] = paths[i] + jdf = self._ssql_ctx.parquetFile(jpaths) + return DataFrame(jdf, self) + + def jsonFile(self, path, schema=None, samplingRatio=1.0): + """Loads a text file storing one JSON object per line as a :class:`DataFrame`. + + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. + + >>> import tempfile, shutil + >>> jsonFile = tempfile.mkdtemp() + >>> shutil.rmtree(jsonFile) + >>> with open(jsonFile, 'w') as f: + ... f.writelines(jsonStrings) + >>> df1 = sqlContext.jsonFile(jsonFile) + >>> df1.printSchema() + root + |-- field1: long (nullable = true) + |-- field2: string (nullable = true) + |-- field3: struct (nullable = true) + | |-- field4: long (nullable = true) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType()), + ... StructField("field3", + ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) + >>> df2 = sqlContext.jsonFile(jsonFile, schema) + >>> df2.printSchema() + root + |-- field2: string (nullable = true) + |-- field3: struct (nullable = true) + | |-- field5: array (nullable = true) + | | |-- element: integer (containsNull = true) + """ + if schema is None: + df = self._ssql_ctx.jsonFile(path, samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) + + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): + """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. + + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. + + >>> df1 = sqlContext.jsonRDD(json) + >>> df1.first() + Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) + + >>> df2 = sqlContext.jsonRDD(json, df1.schema) + >>> df2.first() + Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType()), + ... StructField("field3", + ... StructType([StructField("field5", ArrayType(IntegerType()))])) + ... ]) + >>> df3 = sqlContext.jsonRDD(json, schema) + >>> df3.first() + Row(field2=u'row1', field3=Row(field5=None)) + """ + + def func(iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x + keyed = rdd.mapPartitions(func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._jvm.BytesToString()) + if schema is None: + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) + + def load(self, path=None, source=None, schema=None, **options): + """Returns the dataset in a data source as a :class:`DataFrame`. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Optionally, a schema can be provided as the schema of the returned DataFrame. + """ + if path is not None: + options["path"] = path + if source is None: + source = self.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + joptions = MapConverter().convert(options, + self._sc._gateway._gateway_client) + if schema is None: + df = self._ssql_ctx.load(source, joptions) + else: + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.load(source, scala_datatype, joptions) + return DataFrame(df, self) + + def createExternalTable(self, tableName, path=None, source=None, + schema=None, **options): + """Creates an external table based on the dataset in a data source. + + It returns the DataFrame associated with the external table. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and + created external table. + """ + if path is not None: + options["path"] = path + if source is None: + source = self.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + joptions = MapConverter().convert(options, + self._sc._gateway._gateway_client) + if schema is None: + df = self._ssql_ctx.createExternalTable(tableName, source, joptions) + else: + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype, + joptions) + return DataFrame(df, self) + + def sql(self, sqlQuery): + """Returns a :class:`DataFrame` representing the result of the given query. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + """ + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) + + def table(self, tableName): + """Returns the specified table as a :class:`DataFrame`. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + return DataFrame(self._ssql_ctx.table(tableName), self) + + def tables(self, dbName=None): + """Returns a :class:`DataFrame` containing names of tables in the given database. + + If ``dbName`` is not specified, the current database will be used. + + The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` + (a column with :class:`BooleanType` indicating if a table is a temporary one or not). + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.tables() + >>> df2.filter("tableName = 'table1'").first() + Row(tableName=u'table1', isTemporary=True) + """ + if dbName is None: + return DataFrame(self._ssql_ctx.tables(), self) + else: + return DataFrame(self._ssql_ctx.tables(dbName), self) + + def tableNames(self, dbName=None): + """Returns a list of names of tables in the database ``dbName``. + + If ``dbName`` is not specified, the current database will be used. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> "table1" in sqlContext.tableNames() + True + >>> "table1" in sqlContext.tableNames("db") + True + """ + if dbName is None: + return [name for name in self._ssql_ctx.tableNames()] + else: + return [name for name in self._ssql_ctx.tableNames(dbName)] + + def cacheTable(self, tableName): + """Caches the specified table in-memory.""" + self._ssql_ctx.cacheTable(tableName) + + def uncacheTable(self, tableName): + """Removes the specified table from the in-memory cache.""" + self._ssql_ctx.uncacheTable(tableName) + + def clearCache(self): + """Removes all cached tables from the in-memory cache. """ + self._ssql_ctx.clearCache() + + +class HiveContext(SQLContext): + """A variant of Spark SQL that integrates with data stored in Hive. + + Configuration for Hive is read from ``hive-site.xml`` on the classpath. + It supports running both SQL and HiveQL commands. + + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new + :class:`HiveContext` in the JVM, instead we make all calls to this object. + """ + + def __init__(self, sparkContext, hiveContext=None): + SQLContext.__init__(self, sparkContext) + if hiveContext: + self._scala_HiveContext = hiveContext + + @property + def _ssql_ctx(self): + try: + if not hasattr(self, '_scala_HiveContext'): + self._scala_HiveContext = self._get_hive_ctx() + return self._scala_HiveContext + except Py4JError as e: + raise Exception("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " + "build/sbt assembly", e) + + def _get_hive_ctx(self): + return self._jvm.HiveContext(self._jsc.sc()) + + +class UDFRegistration(object): + """Wrapper for user-defined function registration.""" + + def __init__(self, sqlContext): + self.sqlContext = sqlContext + + def register(self, name, f, returnType=StringType()): + return self.sqlContext.registerFunction(name, f, returnType) + + register.__doc__ = SQLContext.registerFunction.__doc__ + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.context + globs = pyspark.sql.context.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['rdd'] = rdd = sc.parallelize( + [Row(field1=1, field2="row1"), + Row(field1=2, field2="row2"), + Row(field1=3, field2="row3")] + ) + globs['df'] = rdd.toDF() + jsonStrings = [ + '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' + '"field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", ' + '"field3":{"field4":33, "field5": []}}' + ] + globs['jsonStrings'] = jsonStrings + globs['json'] = sc.parallelize(jsonStrings) + (failure_count, test_count) = doctest.testmod( + pyspark.sql.context, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py new file mode 100644 index 0000000000000..ef91a9c4f522d --- /dev/null +++ b/python/pyspark/sql/dataframe.py @@ -0,0 +1,1205 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import itertools +import warnings +import random + +from py4j.java_collections import ListConverter, MapConverter + +from pyspark.context import SparkContext +from pyspark.rdd import RDD, _load_from_socket +from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.storagelevel import StorageLevel +from pyspark.traceback_utils import SCCallSiteSync +from pyspark.sql.types import * +from pyspark.sql.types import _create_cls, _parse_datatype_json_string + + +__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"] + + +class DataFrame(object): + """A distributed collection of data grouped into named columns. + + A :class:`DataFrame` is equivalent to a relational table in Spark SQL, + and can be created using various functions in :class:`SQLContext`:: + + people = sqlContext.parquetFile("...") + + Once created, it can be manipulated using the various domain-specific-language + (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. + + To select a column from the data frame, use the apply method:: + + ageCol = people.age + + A more concrete example:: + + # To create DataFrame using SQLContext + people = sqlContext.parquetFile("...") + department = sqlContext.parquetFile("...") + + people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + self._sc = sql_ctx and sql_ctx._sc + self.is_cached = False + self._schema = None # initialized lazily + + @property + def rdd(self): + """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. + """ + if not hasattr(self, '_lazy_rdd'): + jrdd = self._jdf.javaToPython() + rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + schema = self.schema + + def applySchema(it): + cls = _create_cls(schema) + return itertools.imap(cls, it) + + self._lazy_rdd = rdd.mapPartitions(applySchema) + + return self._lazy_rdd + + @property + def na(self): + """Returns a :class:`DataFrameNaFunctions` for handling missing values. + """ + return DataFrameNaFunctions(self) + + def toJSON(self, use_unicode=False): + """Converts a :class:`DataFrame` into a :class:`RDD` of string. + + Each row is turned into a JSON document as one element in the returned RDD. + + >>> df.toJSON().first() + '{"age":2,"name":"Alice"}' + """ + rdd = self._jdf.toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + + def saveAsParquetFile(self, path): + """Saves the contents as a Parquet file, preserving the schema. + + Files that are written out using this method can be read back in as + a :class:`DataFrame` using :func:`SQLContext.parquetFile`. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) + >>> sorted(df2.collect()) == sorted(df.collect()) + True + """ + self._jdf.saveAsParquetFile(path) + + def registerTempTable(self, name): + """Registers this RDD as a temporary table using the given name. + + The lifetime of this temporary table is tied to the :class:`SQLContext` + that was used to create this :class:`DataFrame`. + + >>> df.registerTempTable("people") + >>> df2 = sqlContext.sql("select * from people") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + self._jdf.registerTempTable(name) + + def registerAsTable(self, name): + """DEPRECATED: use :func:`registerTempTable` instead""" + warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + self.registerTempTable(name) + + def insertInto(self, tableName, overwrite=False): + """Inserts the contents of this :class:`DataFrame` into the specified table. + + Optionally overwriting any existing data. + """ + self._jdf.insertInto(tableName, overwrite) + + def _java_save_mode(self, mode): + """Returns the Java save mode based on the Python save mode represented by a string. + """ + jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode + jmode = jSaveMode.ErrorIfExists + mode = mode.lower() + if mode == "append": + jmode = jSaveMode.Append + elif mode == "overwrite": + jmode = jSaveMode.Overwrite + elif mode == "ignore": + jmode = jSaveMode.Ignore + elif mode == "error": + pass + else: + raise ValueError( + "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.") + return jmode + + def saveAsTable(self, tableName, source=None, mode="error", **options): + """Saves the contents of this :class:`DataFrame` to a data source as a table. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Additionally, mode is used to specify the behavior of the saveAsTable operation when + table already exists in the data source. There are four modes: + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + """ + if source is None: + source = self.sql_ctx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + jmode = self._java_save_mode(mode) + joptions = MapConverter().convert(options, + self.sql_ctx._sc._gateway._gateway_client) + self._jdf.saveAsTable(tableName, source, jmode, joptions) + + def save(self, path=None, source=None, mode="error", **options): + """Saves the contents of the :class:`DataFrame` to a data source. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Additionally, mode is used to specify the behavior of the save operation when + data already exists in the data source. There are four modes: + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + """ + if path is not None: + options["path"] = path + if source is None: + source = self.sql_ctx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + jmode = self._java_save_mode(mode) + joptions = MapConverter().convert(options, + self._sc._gateway._gateway_client) + self._jdf.save(source, jmode, joptions) + + @property + def schema(self): + """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`. + + >>> df.schema + StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) + """ + if self._schema is None: + self._schema = _parse_datatype_json_string(self._jdf.schema().json()) + return self._schema + + def printSchema(self): + """Prints out the schema in the tree format. + + >>> df.printSchema() + root + |-- age: integer (nullable = true) + |-- name: string (nullable = true) + + """ + print (self._jdf.schema().treeString()) + + def explain(self, extended=False): + """Prints the (logical and physical) plans to the console for debugging purpose. + + :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. + + >>> df.explain() + PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:... + + >>> df.explain(True) + == Parsed Logical Plan == + ... + == Analyzed Logical Plan == + ... + == Optimized Logical Plan == + ... + == Physical Plan == + ... + == RDD == + """ + if extended: + print self._jdf.queryExecution().toString() + else: + print self._jdf.queryExecution().executedPlan().toString() + + def isLocal(self): + """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally + (without any Spark executors). + """ + return self._jdf.isLocal() + + def show(self, n=20): + """Prints the first ``n`` rows to the console. + + >>> df + DataFrame[age: int, name: string] + >>> df.show() + age name + 2 Alice + 5 Bob + """ + print self._jdf.showString(n).encode('utf8', 'ignore') + + def __repr__(self): + return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + + def count(self): + """Returns the number of rows in this :class:`DataFrame`. + + >>> df.count() + 2L + """ + return self._jdf.count() + + def collect(self): + """Returns all the records as a list of :class:`Row`. + + >>> df.collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + with SCCallSiteSync(self._sc) as css: + port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) + rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + cls = _create_cls(self.schema) + return [cls(r) for r in rs] + + def limit(self, num): + """Limits the result count to the number specified. + + >>> df.limit(1).collect() + [Row(age=2, name=u'Alice')] + >>> df.limit(0).collect() + [] + """ + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) + + def take(self, num): + """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. + + >>> df.take(2) + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + return self.limit(num).collect() + + def map(self, f): + """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`. + + This is a shorthand for ``df.rdd.map()``. + + >>> df.map(lambda p: p.name).collect() + [u'Alice', u'Bob'] + """ + return self.rdd.map(f) + + def flatMap(self, f): + """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, + and then flattening the results. + + This is a shorthand for ``df.rdd.flatMap()``. + + >>> df.flatMap(lambda p: p.name).collect() + [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b'] + """ + return self.rdd.flatMap(f) + + def mapPartitions(self, f, preservesPartitioning=False): + """Returns a new :class:`RDD` by applying the ``f`` function to each partition. + + This is a shorthand for ``df.rdd.mapPartitions()``. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 + """ + return self.rdd.mapPartitions(f, preservesPartitioning) + + def foreach(self, f): + """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. + + This is a shorthand for ``df.rdd.foreach()``. + + >>> def f(person): + ... print person.name + >>> df.foreach(f) + """ + return self.rdd.foreach(f) + + def foreachPartition(self, f): + """Applies the ``f`` function to each partition of this :class:`DataFrame`. + + This a shorthand for ``df.rdd.foreachPartition()``. + + >>> def f(people): + ... for person in people: + ... print person.name + >>> df.foreachPartition(f) + """ + return self.rdd.foreachPartition(f) + + def cache(self): + """ Persists with the default storage level (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self._jdf.cache() + return self + + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + """Sets the storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) + return self + + def unpersist(self, blocking=True): + """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from + memory and disk. + """ + self.is_cached = False + self._jdf.unpersist(blocking) + return self + + # def coalesce(self, numPartitions, shuffle=False): + # rdd = self._jdf.coalesce(numPartitions, shuffle, None) + # return DataFrame(rdd, self.sql_ctx) + + def repartition(self, numPartitions): + """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions. + + >>> df.repartition(10).rdd.getNumPartitions() + 10 + """ + return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + + def distinct(self): + """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. + + >>> df.distinct().count() + 2L + """ + return DataFrame(self._jdf.distinct(), self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """Returns a sampled subset of this :class:`DataFrame`. + + >>> df.sample(False, 0.5, 97).count() + 1L + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxint) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + @property + def dtypes(self): + """Returns all column names and their data types as a list. + + >>> df.dtypes + [('age', 'int'), ('name', 'string')] + """ + return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] + + @property + def columns(self): + """Returns all column names as a list. + + >>> df.columns + [u'age', u'name'] + """ + return [f.name for f in self.schema.fields] + + def join(self, other, joinExprs=None, joinType=None): + """Joins with another :class:`DataFrame`, using the given join expression. + + The following performs a full outer join between ``df1`` and ``df2``. + + :param other: Right side of the join + :param joinExprs: Join expression + :param joinType: str, default 'inner'. + One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + + >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() + [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + """ + + if joinExprs is None: + jdf = self._jdf.join(other._jdf) + else: + assert isinstance(joinExprs, Column), "joinExprs should be Column" + if joinType is None: + jdf = self._jdf.join(other._jdf, joinExprs._jc) + else: + assert isinstance(joinType, basestring), "joinType should be basestring" + jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) + return DataFrame(jdf, self.sql_ctx) + + def sort(self, *cols): + """Returns a new :class:`DataFrame` sorted by the specified column(s). + + :param cols: list of :class:`Column` to sort by. + + >>> df.sort(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.orderBy(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> from pyspark.sql.functions import * + >>> df.sort(asc("age")).collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.orderBy(desc("age"), "name").collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + """ + if not cols: + raise ValueError("should sort by at least one column") + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + orderBy = sort + + def describe(self, *cols): + """Computes statistics for numeric columns. + + This include count, mean, stddev, min, and max. If no columns are + given, this function computes statistics for all numerical columns. + + >>> df.describe().show() + summary age + count 2 + mean 3.5 + stddev 1.5 + min 2 + max 5 + """ + cols = ListConverter().convert(cols, + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)) + return DataFrame(jdf, self.sql_ctx) + + def head(self, n=None): + """ + Returns the first ``n`` rows as a list of :class:`Row`, + or the first :class:`Row` if ``n`` is ``None.`` + + >>> df.head() + Row(age=2, name=u'Alice') + >>> df.head(1) + [Row(age=2, name=u'Alice')] + """ + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + + def first(self): + """Returns the first row as a :class:`Row`. + + >>> df.first() + Row(age=2, name=u'Alice') + """ + return self.head() + + def __getitem__(self, item): + """Returns the column as a :class:`Column`. + + >>> df.select(df['age']).collect() + [Row(age=2), Row(age=5)] + >>> df[ ["name", "age"]].collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df[ df.age > 3 ].collect() + [Row(age=5, name=u'Bob')] + """ + if isinstance(item, basestring): + jc = self._jdf.apply(item) + return Column(jc) + elif isinstance(item, Column): + return self.filter(item) + elif isinstance(item, list): + return self.select(*item) + else: + raise IndexError("unexpected index: %s" % item) + + def __getattr__(self, name): + """Returns the :class:`Column` denoted by ``name``. + + >>> df.select(df.age).collect() + [Row(age=2), Row(age=5)] + """ + if name.startswith("__"): + raise AttributeError(name) + jc = self._jdf.apply(name) + return Column(jc) + + def select(self, *cols): + """Projects a set of expressions and returns a new :class:`DataFrame`. + + :param cols: list of column names (string) or expressions (:class:`Column`). + If one of the column names is '*', that column is expanded to include all columns + in the current DataFrame. + + >>> df.select('*').collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('name', 'age').collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df.select(df.name, (df.age + 10).alias('age')).collect() + [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] + """ + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + def selectExpr(self, *expr): + """Projects a set of SQL expressions and returns a new :class:`DataFrame`. + + This is a variant of :func:`select` that accepts SQL expressions. + + >>> df.selectExpr("age * 2", "abs(age)").collect() + [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] + """ + jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) + jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) + return DataFrame(jdf, self.sql_ctx) + + def filter(self, condition): + """Filters rows using the given condition. + + :func:`where` is an alias for :func:`filter`. + + :param condition: a :class:`Column` of :class:`types.BooleanType` + or a string of SQL expression. + + >>> df.filter(df.age > 3).collect() + [Row(age=5, name=u'Bob')] + >>> df.where(df.age == 2).collect() + [Row(age=2, name=u'Alice')] + + >>> df.filter("age > 3").collect() + [Row(age=5, name=u'Bob')] + >>> df.where("age = 2").collect() + [Row(age=2, name=u'Alice')] + """ + if isinstance(condition, basestring): + jdf = self._jdf.filter(condition) + elif isinstance(condition, Column): + jdf = self._jdf.filter(condition._jc) + else: + raise TypeError("condition should be string or Column") + return DataFrame(jdf, self.sql_ctx) + + where = filter + + def groupBy(self, *cols): + """Groups the :class:`DataFrame` using the specified columns, + so we can run aggregation on them. See :class:`GroupedData` + for all the available aggregate functions. + + :param cols: list of columns to group by. + Each element should be a column name (string) or an expression (:class:`Column`). + + >>> df.groupBy().avg().collect() + [Row(AVG(age)=3.5)] + >>> df.groupBy('name').agg({'age': 'mean'}).collect() + [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + >>> df.groupBy(df.name).avg().collect() + [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + """ + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return GroupedData(jdf, self.sql_ctx) + + def agg(self, *exprs): + """ Aggregate on the entire :class:`DataFrame` without groups + (shorthand for ``df.groupBy.agg()``). + + >>> df.agg({"age": "max"}).collect() + [Row(MAX(age)=5)] + >>> from pyspark.sql import functions as F + >>> df.agg(F.min(df.age)).collect() + [Row(MIN(age)=2)] + """ + return self.groupBy().agg(*exprs) + + def unionAll(self, other): + """ Return a new :class:`DataFrame` containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. + """ + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + + def intersect(self, other): + """ Return a new :class:`DataFrame` containing rows only in + both this frame and another frame. + + This is equivalent to `INTERSECT` in SQL. + """ + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + def subtract(self, other): + """ Return a new :class:`DataFrame` containing rows in this frame + but not in another frame. + + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + def dropna(self, how='any', thresh=None, subset=None): + """Returns a new :class:`DataFrame` omitting rows with null values. + + This is an alias for ``na.drop()``. + + :param how: 'any' or 'all'. + If 'any', drop a row if it contains any nulls. + If 'all', drop a row only if all its values are null. + :param thresh: int, default None + If specified, drop rows that have less than `thresh` non-null values. + This overwrites the `how` parameter. + :param subset: optional list of column names to consider. + + >>> df4.dropna().show() + age height name + 10 80 Alice + + >>> df4.na.drop().show() + age height name + 10 80 Alice + """ + if how is not None and how not in ['any', 'all']: + raise ValueError("how ('" + how + "') should be 'any' or 'all'") + + if subset is None: + subset = self.columns + elif isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + if thresh is None: + thresh = len(subset) if how == 'any' else 1 + + cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) + cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) + return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx) + + def fillna(self, value, subset=None): + """Replace null values, alias for ``na.fill()``. + + :param value: int, long, float, string, or dict. + Value to replace null values with. + If the value is a dict, then `subset` is ignored and `value` must be a mapping + from column name (string) to replacement value. The replacement value must be + an int, long, float, or string. + :param subset: optional list of column names to consider. + Columns specified in subset that do not have matching data type are ignored. + For example, if `value` is a string, and subset contains a non-string column, + then the non-string column is simply ignored. + + >>> df4.fillna(50).show() + age height name + 10 80 Alice + 5 50 Bob + 50 50 Tom + 50 50 null + + >>> df4.fillna({'age': 50, 'name': 'unknown'}).show() + age height name + 10 80 Alice + 5 null Bob + 50 null Tom + 50 null unknown + + >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() + age height name + 10 80 Alice + 5 null Bob + 50 null Tom + 50 null unknown + """ + if not isinstance(value, (float, int, long, basestring, dict)): + raise ValueError("value should be a float, int, long, string, or dict") + + if isinstance(value, (int, long)): + value = float(value) + + if isinstance(value, dict): + value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client) + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + elif subset is None: + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + else: + if isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) + cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) + return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx) + + def withColumn(self, colName, col): + """Returns a new :class:`DataFrame` by adding a column. + + :param colName: string, name of the new column. + :param col: a :class:`Column` expression for the new column. + + >>> df.withColumn('age2', df.age + 2).collect() + [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ + return self.select('*', col.alias(colName)) + + def withColumnRenamed(self, existing, new): + """REturns a new :class:`DataFrame` by renaming an existing column. + + :param existing: string, name of the existing column to rename. + :param col: string, new name of the column. + + >>> df.withColumnRenamed('age', 'age2').collect() + [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] + """ + cols = [Column(_to_java_column(c)).alias(new) + if c == existing else c + for c in self.columns] + return self.select(*cols) + + def toPandas(self): + """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + + This is only available if Pandas is installed and available. + + >>> df.toPandas() # doctest: +SKIP + age name + 0 2 Alice + 1 5 Bob + """ + import pandas as pd + return pd.DataFrame.from_records(self.collect(), columns=self.columns) + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """SchemaRDD is deprecated, please use :class:`DataFrame`. + """ + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +def df_varargs_api(f): + def _api(self, *args): + jargs = ListConverter().convert(args, + self.sql_ctx._sc._gateway._gateway_client) + name = f.__name__ + jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs)) + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedData(object): + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by :func:`DataFrame.groupBy`. + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + def agg(self, *exprs): + """Compute aggregates and returns the result as a :class:`DataFrame`. + + The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + + If ``exprs`` is a single :class:`dict` mapping from string to string, then the key + is the column to perform aggregation on, and the value is the aggregate function. + + Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + + :param exprs: a dict mapping from column name (string) to aggregate functions (string), + or a list of :class:`Column`. + + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"*": "count"}).collect() + [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)] + + >>> from pyspark.sql import functions as F + >>> gdf.agg(F.min(df.age)).collect() + [Row(MIN(age)=5), Row(MIN(age)=2)] + """ + assert exprs, "exprs should not be empty" + if len(exprs) == 1 and isinstance(exprs[0], dict): + jmap = MapConverter().convert(exprs[0], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(jmap) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" + jcols = ListConverter().convert([c._jc for c in exprs[1:]], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """Counts the number of records for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ + + @df_varargs_api + def mean(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().mean('age').collect() + [Row(AVG(age)=3.5)] + >>> df3.groupBy().mean('age', 'height').collect() + [Row(AVG(age)=3.5, AVG(height)=82.5)] + """ + + @df_varargs_api + def avg(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().avg('age').collect() + [Row(AVG(age)=3.5)] + >>> df3.groupBy().avg('age', 'height').collect() + [Row(AVG(age)=3.5, AVG(height)=82.5)] + """ + + @df_varargs_api + def max(self, *cols): + """Computes the max value for each numeric columns for each group. + + >>> df.groupBy().max('age').collect() + [Row(MAX(age)=5)] + >>> df3.groupBy().max('age', 'height').collect() + [Row(MAX(age)=5, MAX(height)=85)] + """ + + @df_varargs_api + def min(self, *cols): + """Computes the min value for each numeric column for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().min('age').collect() + [Row(MIN(age)=2)] + >>> df3.groupBy().min('age', 'height').collect() + [Row(MIN(age)=2, MIN(height)=80)] + """ + + @df_varargs_api + def sum(self, *cols): + """Compute the sum for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().sum('age').collect() + [Row(SUM(age)=7)] + >>> df3.groupBy().sum('age', 'height').collect() + [Row(SUM(age)=7, SUM(height)=165)] + """ + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.functions.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.functions.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol + + +def _unary_op(name, doc="unary operator"): + """ Create a method for given unary operator """ + def _(self): + jc = getattr(self._jc, name)() + return Column(jc) + _.__doc__ = doc + return _ + + +def _func_op(name, doc=''): + def _(self): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +def _bin_op(name, doc="binary operator"): + """ Create a method for given binary operator + """ + def _(self, other): + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, name)(jc) + return Column(njc) + _.__doc__ = doc + return _ + + +def _reverse_op(name, doc="binary operator"): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + jother = _create_column_from_literal(other) + jc = getattr(jother, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +class Column(object): + + """ + A column in a DataFrame. + + :class:`Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + + df.colName + df["colName"] + + # 2. Create from an expression + df.colName + 1 + 1 / df.colName + """ + + def __init__(self, jc): + self._jc = jc + + # arithmetic operators + __neg__ = _func_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") + + # logistic operators + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _func_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("getItem") + getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + + def substr(self, startPos, length): + """ + Return a :class:`Column` which is a substring of the column + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.select(df.name.substr(1, 3).alias("col")).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + jc = self._jc.substr(startPos, length) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, length._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc) + + __getslice__ = substr + + def inSet(self, *cols): + """ A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.inSet("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.inSet([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] + sc = SparkContext._active_spark_context + jcols = ListConverter().convert(cols, sc._gateway._gateway_client) + jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + # order + asc = _unary_op("asc", "Returns a sort expression based on the" + " ascending order of the given column name.") + desc = _unary_op("desc", "Returns a sort expression based on the" + " descending order of the given column name.") + + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + + def alias(self, alias): + """Return a alias for this column + + >>> df.select(df.age.alias("age2")).collect() + [Row(age2=2), Row(age2=5)] + """ + return Column(getattr(self._jc, "as")(alias)) + + def cast(self, dataType): + """ Convert the column into type `dataType` + + >>> df.select(df.age.cast("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + else: + raise TypeError("unexpected type: %s" % type(dataType)) + return Column(jc) + + def __repr__(self): + return 'Column<%s>' % self._jc.toString().encode('utf8') + + +class DataFrameNaFunctions(object): + """Functionality for working with missing data in :class:`DataFrame`. + """ + + def __init__(self, df): + self.df = df + + def drop(self, how='any', thresh=None, subset=None): + return self.df.dropna(how=how, thresh=thresh, subset=subset) + + drop.__doc__ = DataFrame.dropna.__doc__ + + def fill(self, value, subset=None): + return self.df.fillna(value=value, subset=subset) + + fill.__doc__ = DataFrame.fillna.__doc__ + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.dataframe + globs = pyspark.sql.dataframe.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() + globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), + Row(name='Bob', age=5, height=85)]).toDF() + + globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), + Row(name='Bob', age=5, height=None), + Row(name='Tom', age=None, height=None), + Row(name=None, age=None, height=None)]).toDF() + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.dataframe, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py new file mode 100644 index 0000000000000..daeb6916b58bc --- /dev/null +++ b/python/pyspark/sql/functions.py @@ -0,0 +1,175 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A collections of builtin functions +""" + +from itertools import imap + +from py4j.java_collections import ListConverter + +from pyspark import SparkContext +from pyspark.rdd import _prepare_for_python_RDD +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.sql.types import StringType +from pyspark.sql.dataframe import Column, _to_java_column + + +__all__ = ['countDistinct', 'approxCountDistinct', 'udf'] + + +def _create_function(name, doc=""): + """ Create a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +_functions = { + 'lit': 'Creates a :class:`Column` of literal value.', + 'col': 'Returns a :class:`Column` based on the given column name.', + 'column': 'Returns a :class:`Column` based on the given column name.', + 'asc': 'Returns a sort expression based on the ascending order of the given column name.', + 'desc': 'Returns a sort expression based on the descending order of the given column name.', + + 'upper': 'Converts a string expression to upper case.', + 'lower': 'Converts a string expression to upper case.', + 'sqrt': 'Computes the square root of the specified float value.', + 'abs': 'Computes the absolutle value.', + + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', +} + + +for _name, _doc in _functions.items(): + globals()[_name] = _create_function(_name, _doc) +del _name, _doc +__all__ += _functions.keys() +__all__.sort() + + +def countDistinct(col, *cols): + """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. + + >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect() + [Row(c=2)] + + >>> df.agg(countDistinct("age", "name").alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client) + jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + +def approxCountDistinct(col, rsd=None): + """Returns a new :class:`Column` for approximate distinct count of ``col``. + + >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + +class UserDefinedFunction(object): + """ + User defined function in Python + """ + def __init__(self, func, returnType): + self.func = func + self.returnType = returnType + self._broadcast = None + self._judf = self._create_judf() + + def _create_judf(self): + f = self.func # put it in closure `func` + func = lambda _, it: imap(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + sc = SparkContext._active_spark_context + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(self.returnType.json()) + fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ + judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, + includes, sc.pythonExec, broadcast_vars, + sc._javaAccumulator, jdt) + return judf + + def __del__(self): + if self._broadcast is not None: + self._broadcast.unpersist() + self._broadcast = None + + def __call__(self, *cols): + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + sc._gateway._gateway_client) + jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + +def udf(f, returnType=StringType()): + """Creates a :class:`Column` expression representing a user defined function (UDF). + + >>> from pyspark.sql.types import IntegerType + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> df.select(slen(df.name).alias('slen')).collect() + [Row(slen=5), Row(slen=3)] + """ + return UserDefinedFunction(f, returnType) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.functions + globs = pyspark.sql.functions.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() + (failure_count, test_count) = doctest.testmod( + pyspark.sql.functions, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py new file mode 100644 index 0000000000000..b3a6a2c6a9229 --- /dev/null +++ b/python/pyspark/sql/tests.py @@ -0,0 +1,623 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for pyspark.sql; additional tests are implemented as doctests in +individual modules. +""" +import os +import sys +import pydoc +import shutil +import tempfile +import pickle +import functools + +import py4j + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.sql import SQLContext, HiveContext, Column, Row +from pyspark.sql.types import * +from pyspark.sql.types import UserDefinedType, _infer_type +from pyspark.tests import ReusedPySparkTestCase +from pyspark.sql.functions import UserDefinedFunction + + +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + +class DataTypeTests(unittest.TestCase): + # regression test for SPARK-6055 + def test_data_type_eq(self): + lt = LongType() + lt2 = pickle.loads(pickle.dumps(LongType())) + self.assertEquals(lt, lt2) + + +class SQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sqlCtx = SQLContext(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = cls.sc.parallelize(cls.testData) + cls.df = rdd.toDF() + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_udf_with_callable(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + class PlusFour: + def __call__(self, col): + if col is not None: + return col + 4 + + call = PlusFour() + pudf = UserDefinedFunction(call, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf_with_partial_function(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + def some_func(col, param): + if col is not None: + return col + param + + pfunc = functools.partial(some_func, param=4) + pudf = UserDefinedFunction(pfunc, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + + def test_udf_with_array_type(self): + d = [Row(l=range(3), d={"key": range(5)})] + rdd = self.sc.parallelize(d) + self.sqlCtx.createDataFrame(rdd).registerTempTable("test") + self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(range(3), l1) + self.assertEqual(1, l2) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema + + # cache and checkpoint + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() + + def test_apply_schema_to_row(self): + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema) + self.assertEqual(df.collect(), df2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + + def test_serialize_nested_array_and_map(self): + d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.createDataFrame(rdd) + row = df.head() + self.assertEqual(1, len(row.l)) + self.assertEqual(1, row.l[0].a) + self.assertEqual("2", row.d["key"].d) + + l = df.map(lambda x: x.l).first() + self.assertEqual(1, len(l)) + self.assertEqual('s', l[0].b) + + d = df.map(lambda x: x.d).first() + self.assertEqual(1, len(d)) + self.assertEqual(1.0, d["key"].c) + + row = df.map(lambda x: x.d["key"]).first() + self.assertEqual(1.0, row.c) + self.assertEqual("2", row.d) + + def test_infer_schema(self): + d = [Row(l=[], d={}, s=None), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.createDataFrame(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + def test_infer_nested_schema(self): + NestedRow = Row("f1", "f2") + nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), + NestedRow([2, 3], {"row2": 2.0})]) + df = self.sqlCtx.inferSchema(nestedRdd1) + self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) + + nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), + NestedRow([[2, 3], [3, 4]], [2, 3])]) + df = self.sqlCtx.inferSchema(nestedRdd2) + self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) + + from collections import namedtuple + CustomRow = namedtuple('CustomRow', 'field1 field2') + rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), + CustomRow(field1=2, field2="row2"), + CustomRow(field1=3, field2="row3")]) + df = self.sqlCtx.inferSchema(rdd) + self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + + def test_apply_schema(self): + from datetime import date, datetime + rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3], None)]) + schema = StructType([ + StructField("byte1", ByteType(), False), + StructField("byte2", ByteType(), False), + StructField("short1", ShortType(), False), + StructField("short2", ShortType(), False), + StructField("int1", IntegerType(), False), + StructField("float1", FloatType(), False), + StructField("date1", DateType(), False), + StructField("time1", TimestampType(), False), + StructField("map1", MapType(StringType(), IntegerType(), False), False), + StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), + StructField("list1", ArrayType(ByteType(), False), False), + StructField("null1", DoubleType(), True)]) + df = self.sqlCtx.applySchema(rdd, schema) + results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, + x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) + r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), + datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + self.assertEqual(r, results.first()) + + df.registerTempTable("table2") + r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + "float1 + 1.5 as float1 FROM table2").first() + + self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) + + from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type + rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3])]) + abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" + schema = _parse_schema_abstract(abstract) + typedSchema = _infer_schema_type(rdd.first(), schema) + df = self.sqlCtx.applySchema(rdd, typedSchema) + r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3]) + self.assertEqual(r, tuple(df.first())) + + def test_struct_in_map(self): + d = [Row(m={Row(i=1): Row(s="")})] + df = self.sc.parallelize(d).toDF() + k, v = df.head().m.items()[0] + self.assertEqual(1, k.i) + self.assertEqual("", v.s) + + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + df = self.sc.parallelize([row]).toDF() + df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() + self.assertEqual(1, row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) + + def test_infer_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.sc.parallelize([row]).toDF() + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df = rdd.toDF(schema) + point = df.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.sql.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df0 = self.sc.parallelize([row]).toDF() + output_dir = os.path.join(self.tempdir.name, "labeled_point") + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_column_operators(self): + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbool = (ci & ci), (ci | ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbool)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + + from pyspark.sql import functions + self.assertEqual((0, u'99'), + tuple(g.agg(functions.first(df.key), functions.last(df.value)).first())) + self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) + self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + + def test_save_and_load(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.save(tmpPath, "org.apache.spark.sql.json", "error") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema) + self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + + df.save(tmpPath, "org.apache.spark.sql.json", "overwrite") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath, + noUse="this options will not be used in save.") + actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath, + noUse="this options will not be used in load.") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.sqlCtx.load(path=tmpPath) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + + def test_help_command(self): + # Regression test for SPARK-5464 + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(df) + pydoc.render_doc(df.foo) + pydoc.render_doc(df.take(1)) + + def test_infer_long_type(self): + longrow = [Row(f1='a', f2=100000000000000)] + df = self.sc.parallelize(longrow).toDF() + self.assertEqual(df.schema.fields[1].dataType, LongType()) + + # this saving as Parquet caused issues as well. + output_dir = os.path.join(self.tempdir.name, "infer_long_type") + df.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + self.assertEquals('a', df1.first().f1) + self.assertEquals(100000000000000, df1.first().f2) + + self.assertEqual(_infer_type(1), LongType()) + self.assertEqual(_infer_type(2**10), LongType()) + self.assertEqual(_infer_type(2**20), LongType()) + self.assertEqual(_infer_type(2**31 - 1), LongType()) + self.assertEqual(_infer_type(2**31), LongType()) + self.assertEqual(_infer_type(2**61), LongType()) + self.assertEqual(_infer_type(2**71), LongType()) + + def test_dropna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # shouldn't drop a non-null row + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, 80.1)], schema).dropna().count(), + 1) + + # dropping rows with a single null value + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna().count(), + 0) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='any').count(), + 0) + + # if how = 'all', only drop rows if all values are null + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='all').count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(None, None, None)], schema).dropna(how='all').count(), + 0) + + # how and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 0) + + # threshold + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(thresh=2).count(), + 0) + + # threshold and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 0) + + # thresh should take precedence over how + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna( + how='any', thresh=2, subset=['name', 'age']).count(), + 1) + + def test_fillna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # fillna shouldn't change non-null values + row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first() + self.assertEqual(row.age, 10) + + # fillna with int + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.0) + + # fillna with double + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.1) + + # fillna with string + row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first() + self.assertEqual(row.name, u"hello") + self.assertEqual(row.age, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first() + self.assertEqual(row.name, None) + self.assertEqual(row.age, 50) + self.assertEqual(row.height, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() + self.assertEqual(row.name, "haha") + self.assertEqual(row.age, None) + self.assertEqual(row.height, None) + + +class HiveContextSQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + try: + cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.sqlCtx = None + return + os.unlink(cls.tempdir.name) + _scala_HiveContext =\ + cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) + cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_save_and_load_table(self): + if self.sqlCtx is None: + return # no hive available, skipped + + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath) + actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, + "org.apache.spark.sql.json") + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath) + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.createExternalTable("externalJsonTable", + source="org.apache.spark.sql.json", + schema=schema, path=tmpPath, + noUse="this options will not be used") + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.select("value").collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") + actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py new file mode 100644 index 0000000000000..7e0124b13671b --- /dev/null +++ b/python/pyspark/sql/types.py @@ -0,0 +1,1251 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import decimal +import datetime +import keyword +import warnings +import json +import re +import weakref +from array import array +from operator import itemgetter + + +__all__ = [ + "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", + "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", + "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"] + + +class DataType(object): + """Base class for data types.""" + + def __repr__(self): + return self.__class__.__name__ + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def typeName(cls): + return cls.__name__[:-4].lower() + + def simpleString(self): + return self.typeName() + + def jsonValue(self): + return self.typeName() + + def json(self): + return json.dumps(self.jsonValue(), + separators=(',', ':'), + sort_keys=True) + + +# This singleton pattern does not work with pickle, you will get +# another object after pickle and unpickle +class PrimitiveTypeSingleton(type): + """Metaclass for PrimitiveType""" + + _instances = {} + + def __call__(cls): + if cls not in cls._instances: + cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + return cls._instances[cls] + + +class PrimitiveType(DataType): + """Spark SQL PrimitiveType""" + + __metaclass__ = PrimitiveTypeSingleton + + +class NullType(PrimitiveType): + """Null type. + + The data type representing None, used for the types that cannot be inferred. + """ + + +class StringType(PrimitiveType): + """String data type. + """ + + +class BinaryType(PrimitiveType): + """Binary (byte array) data type. + """ + + +class BooleanType(PrimitiveType): + """Boolean data type. + """ + + +class DateType(PrimitiveType): + """Date (datetime.date) data type. + """ + + +class TimestampType(PrimitiveType): + """Timestamp (datetime.datetime) data type. + """ + + +class DecimalType(DataType): + """Decimal (decimal.Decimal) data type. + """ + + def __init__(self, precision=None, scale=None): + self.precision = precision + self.scale = scale + self.hasPrecisionInfo = precision is not None + + def simpleString(self): + if self.hasPrecisionInfo: + return "decimal(%d,%d)" % (self.precision, self.scale) + else: + return "decimal(10,0)" + + def jsonValue(self): + if self.hasPrecisionInfo: + return "decimal(%d,%d)" % (self.precision, self.scale) + else: + return "decimal" + + def __repr__(self): + if self.hasPrecisionInfo: + return "DecimalType(%d,%d)" % (self.precision, self.scale) + else: + return "DecimalType()" + + +class DoubleType(PrimitiveType): + """Double data type, representing double precision floats. + """ + + +class FloatType(PrimitiveType): + """Float data type, representing single precision floats. + """ + + +class ByteType(PrimitiveType): + """Byte data type, i.e. a signed integer in a single byte. + """ + def simpleString(self): + return 'tinyint' + + +class IntegerType(PrimitiveType): + """Int data type, i.e. a signed 32-bit integer. + """ + def simpleString(self): + return 'int' + + +class LongType(PrimitiveType): + """Long data type, i.e. a signed 64-bit integer. + + If the values are beyond the range of [-9223372036854775808, 9223372036854775807], + please use :class:`DecimalType`. + """ + def simpleString(self): + return 'bigint' + + +class ShortType(PrimitiveType): + """Short data type, i.e. a signed 16-bit integer. + """ + def simpleString(self): + return 'smallint' + + +class ArrayType(DataType): + """Array data type. + + :param elementType: :class:`DataType` of each element in the array. + :param containsNull: boolean, whether the array can contain null (None) values. + """ + + def __init__(self, elementType, containsNull=True): + """ + >>> ArrayType(StringType()) == ArrayType(StringType(), True) + True + >>> ArrayType(StringType(), False) == ArrayType(StringType()) + False + """ + assert isinstance(elementType, DataType), "elementType should be DataType" + self.elementType = elementType + self.containsNull = containsNull + + def simpleString(self): + return 'array<%s>' % self.elementType.simpleString() + + def __repr__(self): + return "ArrayType(%s,%s)" % (self.elementType, + str(self.containsNull).lower()) + + def jsonValue(self): + return {"type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull} + + @classmethod + def fromJson(cls, json): + return ArrayType(_parse_datatype_json_value(json["elementType"]), + json["containsNull"]) + + +class MapType(DataType): + """Map data type. + + :param keyType: :class:`DataType` of the keys in the map. + :param valueType: :class:`DataType` of the values in the map. + :param valueContainsNull: indicates whether values can contain null (None) values. + + Keys in a map data type are not allowed to be null (None). + """ + + def __init__(self, keyType, valueType, valueContainsNull=True): + """ + >>> (MapType(StringType(), IntegerType()) + ... == MapType(StringType(), IntegerType(), True)) + True + >>> (MapType(StringType(), IntegerType(), False) + ... == MapType(StringType(), FloatType())) + False + """ + assert isinstance(keyType, DataType), "keyType should be DataType" + assert isinstance(valueType, DataType), "valueType should be DataType" + self.keyType = keyType + self.valueType = valueType + self.valueContainsNull = valueContainsNull + + def simpleString(self): + return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString()) + + def __repr__(self): + return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, + str(self.valueContainsNull).lower()) + + def jsonValue(self): + return {"type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull} + + @classmethod + def fromJson(cls, json): + return MapType(_parse_datatype_json_value(json["keyType"]), + _parse_datatype_json_value(json["valueType"]), + json["valueContainsNull"]) + + +class StructField(DataType): + """A field in :class:`StructType`. + + :param name: string, name of the field. + :param dataType: :class:`DataType` of the field. + :param nullable: boolean, whether the field can be null (None) or not. + :param metadata: a dict from string to simple type that can be serialized to JSON automatically + """ + + def __init__(self, name, dataType, nullable=True, metadata=None): + """ + >>> (StructField("f1", StringType(), True) + ... == StructField("f1", StringType(), True)) + True + >>> (StructField("f1", StringType(), True) + ... == StructField("f2", StringType(), True)) + False + """ + assert isinstance(dataType, DataType), "dataType should be DataType" + self.name = name + self.dataType = dataType + self.nullable = nullable + self.metadata = metadata or {} + + def simpleString(self): + return '%s:%s' % (self.name, self.dataType.simpleString()) + + def __repr__(self): + return "StructField(%s,%s,%s)" % (self.name, self.dataType, + str(self.nullable).lower()) + + def jsonValue(self): + return {"name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable, + "metadata": self.metadata} + + @classmethod + def fromJson(cls, json): + return StructField(json["name"], + _parse_datatype_json_value(json["type"]), + json["nullable"], + json["metadata"]) + + +class StructType(DataType): + """Struct type, consisting of a list of :class:`StructField`. + + This is the data type representing a :class:`Row`. + """ + + def __init__(self, fields): + """ + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True), + ... StructField("f2", IntegerType(), False)]) + >>> struct1 == struct2 + False + """ + assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType" + self.fields = fields + + def simpleString(self): + return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) + + def __repr__(self): + return ("StructType(List(%s))" % + ",".join(str(field) for field in self.fields)) + + def jsonValue(self): + return {"type": self.typeName(), + "fields": [f.jsonValue() for f in self.fields]} + + @classmethod + def fromJson(cls, json): + return StructType([StructField.fromJson(f) for f in json["fields"]]) + + +class UserDefinedType(DataType): + """User-defined type (UDT). + + .. note:: WARN: Spark Internal Use Only + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT. + """ + raise NotImplementedError("UDT must have a paired Scala UDT.") + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement serialize().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. + """ + raise NotImplementedError("UDT must implement deserialize().") + + def simpleString(self): + return 'udt' + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = json["pyClass"] + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass], -1) + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) + + +_all_primitive_types = dict((v.typeName(), v) + for v in globals().itervalues() + if type(v) is PrimitiveTypeSingleton and + v.__base__ == PrimitiveType) + + +_all_complex_types = dict((v.typeName(), v) + for v in [ArrayType, MapType, StructType]) + + +def _parse_datatype_json_string(json_string): + """Parses the given data type JSON string. + >>> import pickle + >>> def check_datatype(datatype): + ... pickled = pickle.loads(pickle.dumps(datatype)) + ... assert datatype == pickled + ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) + ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) + ... assert datatype == python_datatype + >>> for cls in _all_primitive_types.values(): + ... check_datatype(cls()) + + >>> # Simple ArrayType. + >>> simple_arraytype = ArrayType(StringType(), True) + >>> check_datatype(simple_arraytype) + + >>> # Simple MapType. + >>> simple_maptype = MapType(StringType(), LongType()) + >>> check_datatype(simple_maptype) + + >>> # Simple StructType. + >>> simple_structtype = StructType([ + ... StructField("a", DecimalType(), False), + ... StructField("b", BooleanType(), True), + ... StructField("c", LongType(), True), + ... StructField("d", BinaryType(), False)]) + >>> check_datatype(simple_structtype) + + >>> # Complex StructType. + >>> complex_structtype = StructType([ + ... StructField("simpleArray", simple_arraytype, True), + ... StructField("simpleMap", simple_maptype, True), + ... StructField("simpleStruct", simple_structtype, True), + ... StructField("boolean", BooleanType(), False), + ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) + >>> check_datatype(complex_structtype) + + >>> # Complex ArrayType. + >>> complex_arraytype = ArrayType(complex_structtype, True) + >>> check_datatype(complex_arraytype) + + >>> # Complex MapType. + >>> complex_maptype = MapType(complex_structtype, + ... complex_arraytype, False) + >>> check_datatype(complex_maptype) + + >>> check_datatype(ExamplePointUDT()) + >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + """ + return _parse_datatype_json_value(json.loads(json_string)) + + +_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") + + +def _parse_datatype_json_value(json_value): + if type(json_value) is unicode: + if json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + elif json_value == u'decimal': + return DecimalType() + elif _FIXED_DECIMAL.match(json_value): + m = _FIXED_DECIMAL.match(json_value) + return DecimalType(int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Could not parse datatype: %s" % json_value) + else: + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) + + +# Mapping Python types to Spark SQL DataType +_type_mappings = { + type(None): NullType, + bool: BooleanType, + int: LongType, + long: LongType, + float: DoubleType, + str: StringType, + unicode: StringType, + bytearray: BinaryType, + decimal.Decimal: DecimalType, + datetime.date: DateType, + datetime.datetime: TimestampType, + datetime.time: TimestampType, +} + + +def _infer_type(obj): + """Infer the DataType from obj + + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ + if obj is None: + return NullType() + + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + + dataType = _type_mappings.get(type(obj)) + if dataType is not None: + return dataType() + + if isinstance(obj, dict): + for key, value in obj.iteritems(): + if key is not None and value is not None: + return MapType(_infer_type(key), _infer_type(value), True) + else: + return MapType(NullType(), NullType(), True) + elif isinstance(obj, (list, array)): + for v in obj: + if v is not None: + return ArrayType(_infer_type(obj[0]), True) + else: + return ArrayType(NullType(), True) + else: + try: + return _infer_schema(obj) + except ValueError: + raise ValueError("not supported type: %s" % type(obj)) + + +def _infer_schema(row): + """Infer the schema from dict/namedtuple/object""" + if isinstance(row, dict): + items = sorted(row.items()) + + elif isinstance(row, (tuple, list)): + if hasattr(row, "_fields"): # namedtuple + items = zip(row._fields, tuple(row)) + elif hasattr(row, "__FIELDS__"): # Row + items = zip(row.__FIELDS__, tuple(row)) + else: + names = ['_%d' % i for i in range(1, len(row) + 1)] + items = zip(names, row) + + elif hasattr(row, "__dict__"): # object + items = sorted(row.__dict__.items()) + + else: + raise ValueError("Can not infer schema for type: %s" % type(row)) + + fields = [StructField(k, _infer_type(v), True) for k, v in items] + return StructType(fields) + + +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. + + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) + + +def _has_nulltype(dt): + """ Return whether there is NullType in `dt` or not """ + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +def _merge_type(a, b): + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) + raise TypeError("Can not merge type %s and %s" % (a, b)) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in b.fields) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + for f in a.fields] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType(_merge_type(a.elementType, b.elementType), True) + + elif isinstance(a, MapType): + return MapType(_merge_type(a.keyType, b.keyType), + _merge_type(a.valueType, b.valueType), + True) + else: + return a + + +def _need_converter(dataType): + if isinstance(dataType, StructType): + return True + elif isinstance(dataType, ArrayType): + return _need_converter(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_converter(dataType.keyType) or _need_converter(dataType.valueType) + elif isinstance(dataType, NullType): + return True + else: + return False + + +def _create_converter(dataType): + """Create an converter to drop the names of fields in obj """ + if not _need_converter(dataType): + return lambda x: x + + if isinstance(dataType, ArrayType): + conv = _create_converter(dataType.elementType) + return lambda row: map(conv, row) + + elif isinstance(dataType, MapType): + kconv = _create_converter(dataType.keyType) + vconv = _create_converter(dataType.valueType) + return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems()) + + elif isinstance(dataType, NullType): + return lambda x: None + + elif not isinstance(dataType, StructType): + return lambda x: x + + # dataType must be StructType + names = [f.name for f in dataType.fields] + converters = [_create_converter(f.dataType) for f in dataType.fields] + convert_fields = any(_need_converter(f.dataType) for f in dataType.fields) + + def convert_struct(obj): + if obj is None: + return + + if isinstance(obj, (tuple, list)): + if convert_fields: + return tuple(conv(v) for v, conv in zip(obj, converters)) + else: + return tuple(obj) + + if isinstance(obj, dict): + d = obj + elif hasattr(obj, "__dict__"): # object + d = obj.__dict__ + else: + raise ValueError("Unexpected obj: %s" % obj) + + if convert_fields: + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + else: + return tuple([d.get(name) for name in names]) + + return convert_struct + + +_BRACKETS = {'(': ')', '[': ']', '{': '}'} + + +def _split_schema_abstract(s): + """ + split the schema abstract into fields + + >>> _split_schema_abstract("a b c") + ['a', 'b', 'c'] + >>> _split_schema_abstract("a(a b)") + ['a(a b)'] + >>> _split_schema_abstract("a b[] c{a b}") + ['a', 'b[]', 'c{a b}'] + >>> _split_schema_abstract(" ") + [] + """ + + r = [] + w = '' + brackets = [] + for c in s: + if c == ' ' and not brackets: + if w: + r.append(w) + w = '' + else: + w += c + if c in _BRACKETS: + brackets.append(c) + elif c in _BRACKETS.values(): + if not brackets or c != _BRACKETS[brackets.pop()]: + raise ValueError("unexpected " + c) + + if brackets: + raise ValueError("brackets not closed: %s" % brackets) + if w: + r.append(w) + return r + + +def _parse_field_abstract(s): + """ + Parse a field in schema abstract + + >>> _parse_field_abstract("a") + StructField(a,NullType,true) + >>> _parse_field_abstract("b(c d)") + StructField(b,StructType(...c,NullType,true),StructField(d... + >>> _parse_field_abstract("a[]") + StructField(a,ArrayType(NullType,true),true) + >>> _parse_field_abstract("a{[]}") + StructField(a,MapType(NullType,ArrayType(NullType,true),true),true) + """ + if set(_BRACKETS.keys()) & set(s): + idx = min((s.index(c) for c in _BRACKETS if c in s)) + name = s[:idx] + return StructField(name, _parse_schema_abstract(s[idx:]), True) + else: + return StructField(s, NullType(), True) + + +def _parse_schema_abstract(s): + """ + parse abstract into schema + + >>> _parse_schema_abstract("a b c") + StructType...a...b...c... + >>> _parse_schema_abstract("a[b c] b{}") + StructType...a,ArrayType...b...c...b,MapType... + >>> _parse_schema_abstract("c{} d{a b}") + StructType...c,MapType...d,MapType...a...b... + >>> _parse_schema_abstract("a b(t)").fields[1] + StructField(b,StructType(List(StructField(t,NullType,true))),true) + """ + s = s.strip() + if not s: + return NullType() + + elif s.startswith('('): + return _parse_schema_abstract(s[1:-1]) + + elif s.startswith('['): + return ArrayType(_parse_schema_abstract(s[1:-1]), True) + + elif s.startswith('{'): + return MapType(NullType(), _parse_schema_abstract(s[1:-1])) + + parts = _split_schema_abstract(s) + fields = [_parse_field_abstract(p) for p in parts] + return StructType(fields) + + +def _infer_schema_type(obj, dataType): + """ + Fill the dataType with types inferred from obj + + >>> schema = _parse_schema_abstract("a b c d") + >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) + >>> _infer_schema_type(row, schema) + StructType...LongType...DoubleType...StringType...DateType... + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> _infer_schema_type(row, schema) + StructType...a,ArrayType...b,MapType(StringType,...c,LongType... + """ + if dataType is NullType(): + return _infer_type(obj) + + if not obj: + return NullType() + + if isinstance(dataType, ArrayType): + eType = _infer_schema_type(obj[0], dataType.elementType) + return ArrayType(eType, True) + + elif isinstance(dataType, MapType): + k, v = obj.iteritems().next() + return MapType(_infer_schema_type(k, dataType.keyType), + _infer_schema_type(v, dataType.valueType)) + + elif isinstance(dataType, StructType): + fs = dataType.fields + assert len(fs) == len(obj), \ + "Obj(%s) have different length with fields(%s)" % (obj, fs) + fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) + for o, f in zip(obj, fs)] + return StructType(fields) + + else: + raise ValueError("Unexpected dataType: %s" % dataType) + + +_acceptable_types = { + BooleanType: (bool,), + ByteType: (int, long), + ShortType: (int, long), + IntegerType: (int, long), + LongType: (int, long), + FloatType: (float,), + DoubleType: (float,), + DecimalType: (decimal.Decimal,), + StringType: (str, unicode), + BinaryType: (bytearray,), + DateType: (datetime.date,), + TimestampType: (datetime.datetime,), + ArrayType: (list, tuple, array), + MapType: (dict,), + StructType: (tuple, list), +} + + +def _verify_type(obj, dataType): + """ + Verify the type of obj against dataType, raise an exception if + they do not match. + + >>> _verify_type(None, StructType([])) + >>> _verify_type("", StringType()) + >>> _verify_type(0, LongType()) + >>> _verify_type(range(3), ArrayType(ShortType())) + >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... + >>> _verify_type({}, MapType(StringType(), IntegerType())) + >>> _verify_type((), StructType([])) + >>> _verify_type([], StructType([])) + >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + """ + # all objects are nullable + if obj is None: + return + + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.serialize(obj), dataType.sqlType()) + return + + _type = type(dataType) + assert _type in _acceptable_types, "unknown datatype: %s" % dataType + + # subclass of them can not be deserialized in JVM + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept object in type %s" + % (dataType, type(obj))) + + if isinstance(dataType, ArrayType): + for i in obj: + _verify_type(i, dataType.elementType) + + elif isinstance(dataType, MapType): + for k, v in obj.iteritems(): + _verify_type(k, dataType.keyType) + _verify_type(v, dataType.valueType) + + elif isinstance(dataType, StructType): + if len(obj) != len(dataType.fields): + raise ValueError("Length of object (%d) does not match with " + "length of fields (%d)" % (len(obj), len(dataType.fields))) + for v, f in zip(obj, dataType.fields): + _verify_type(v, f.dataType) + +_cached_cls = weakref.WeakValueDictionary() + + +def _restore_object(dataType, obj): + """ Restore object during unpickling. """ + # use id(dataType) as key to speed up lookup in dict + # Because of batched pickling, dataType will be the + # same object in most cases. + k = id(dataType) + cls = _cached_cls.get(k) + if cls is None: + # use dataType as key to avoid create multiple class + cls = _cached_cls.get(dataType) + if cls is None: + cls = _create_cls(dataType) + _cached_cls[dataType] = cls + _cached_cls[k] = cls + return cls(obj) + + +def _create_object(cls, v): + """ Create an customized object with class `cls`. """ + # datetime.date would be deserialized as datetime.datetime + # from java type, so we need to set it back. + if cls is datetime.date and isinstance(v, datetime.datetime): + return v.date() + return cls(v) if v is not None else v + + +def _create_getter(dt, i): + """ Create a getter for item `i` with schema """ + cls = _create_cls(dt) + + def getter(self): + return _create_object(cls, self[i]) + + return getter + + +def _has_struct_or_date(dt): + """Return whether `dt` is or has StructType/DateType in it""" + if isinstance(dt, StructType): + return True + elif isinstance(dt, ArrayType): + return _has_struct_or_date(dt.elementType) + elif isinstance(dt, MapType): + return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) + elif isinstance(dt, DateType): + return True + elif isinstance(dt, UserDefinedType): + return True + return False + + +def _create_properties(fields): + """Create properties according to fields""" + ps = {} + for i, f in enumerate(fields): + name = f.name + if (name.startswith("__") and name.endswith("__") + or keyword.iskeyword(name)): + warnings.warn("field name %s can not be accessed in Python," + "use position to access it instead" % name) + if _has_struct_or_date(f.dataType): + # delay creating object until accessing it + getter = _create_getter(f.dataType, i) + else: + getter = itemgetter(i) + ps[name] = property(getter) + return ps + + +def _create_cls(dataType): + """ + Create an class by dataType + + The created class is similar to namedtuple, but can have nested schema. + + >>> schema = _parse_schema_abstract("a b c") + >>> row = (1, 1.0, "str") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> import pickle + >>> pickle.loads(pickle.dumps(obj)) + Row(a=1, b=1.0, c='str') + + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> pickle.loads(pickle.dumps(obj)) + Row(a=[1], b={'key': Row(c=1, d=2.0)}) + >>> pickle.loads(pickle.dumps(obj.a)) + [1] + >>> pickle.loads(pickle.dumps(obj.b)) + {'key': Row(c=1, d=2.0)} + """ + + if isinstance(dataType, ArrayType): + cls = _create_cls(dataType.elementType) + + def List(l): + if l is None: + return + return [_create_object(cls, v) for v in l] + + return List + + elif isinstance(dataType, MapType): + kcls = _create_cls(dataType.keyType) + vcls = _create_cls(dataType.valueType) + + def Dict(d): + if d is None: + return + return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) + + return Dict + + elif isinstance(dataType, DateType): + return datetime.date + + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + + elif not isinstance(dataType, StructType): + # no wrapper for primitive types + return lambda x: x + + class Row(tuple): + + """ Row in DataFrame """ + __DATATYPE__ = dataType + __FIELDS__ = tuple(f.name for f in dataType.fields) + __slots__ = () + + # create property for fast access + locals().update(_create_properties(dataType.fields)) + + def asDict(self): + """ Return as a dict """ + return dict((n, getattr(self, n)) for n in self.__FIELDS__) + + def __repr__(self): + # call collect __repr__ for nested objects + return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) + for n in self.__FIELDS__)) + + def __reduce__(self): + return (_restore_object, (self.__DATATYPE__, tuple(self))) + + return Row + + +def _create_row(fields, values): + row = Row(*values) + row.__FIELDS__ = fields + return row + + +class Row(tuple): + + """ + A row in L{DataFrame}. The fields in it can be accessed like attributes. + + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) + + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as + + >>> Person = Row("name", "age") + >>> Person + + >>> Person("Alice", 11) + Row(name='Alice', age=11) + """ + + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + row = tuple.__new__(self, [kwargs[n] for n in names]) + row.__FIELDS__ = names + return row + + else: + raise ValueError("No args or kwargs") + + def asDict(self): + """ + Return as an dict + """ + if not hasattr(self, "__FIELDS__"): + raise TypeError("Cannot convert a Row class into dict") + return dict(zip(self.__FIELDS__, self)) + + # let obect acs like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__FIELDS__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + + def __reduce__(self): + if hasattr(self, "__FIELDS__"): + return (_create_row, (self.__FIELDS__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + if hasattr(self, "__FIELDS__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__FIELDS__, self)) + else: + return "" % ", ".join(self) + + +def _test(): + import doctest + from pyspark.context import SparkContext + # let doctest run in pyspark.sql.types, so DataTypes can be picklable + import pyspark.sql.types + from pyspark.sql import Row, SQLContext + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + globs = pyspark.sql.types.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['ExamplePoint'] = ExamplePoint + globs['ExamplePointUDT'] = ExamplePointUDT + (failure_count, test_count) = doctest.testmod( + pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/status.py b/python/pyspark/status.py new file mode 100644 index 0000000000000..a6fa7dd3144d4 --- /dev/null +++ b/python/pyspark/status.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import namedtuple + +__all__ = ["SparkJobInfo", "SparkStageInfo", "StatusTracker"] + + +class SparkJobInfo(namedtuple("SparkJobInfo", "jobId stageIds status")): + """ + Exposes information about Spark Jobs. + """ + + +class SparkStageInfo(namedtuple("SparkStageInfo", + "stageId currentAttemptId name numTasks numActiveTasks " + "numCompletedTasks numFailedTasks")): + """ + Exposes information about Spark Stages. + """ + + +class StatusTracker(object): + """ + Low-level status reporting APIs for monitoring job and stage progress. + + These APIs intentionally provide very weak consistency semantics; + consumers of these APIs should be prepared to handle empty / missing + information. For example, a job's stage ids may be known but the status + API may not have any information about the details of those stages, so + `getStageInfo` could potentially return `None` for a valid stage id. + + To limit memory usage, these APIs only provide information on recent + jobs / stages. These APIs will provide information for the last + `spark.ui.retainedStages` stages and `spark.ui.retainedJobs` jobs. + """ + def __init__(self, jtracker): + self._jtracker = jtracker + + def getJobIdsForGroup(self, jobGroup=None): + """ + Return a list of all known jobs in a particular job group. If + `jobGroup` is None, then returns all known jobs that are not + associated with a job group. + + The returned list may contain running, failed, and completed jobs, + and may vary across invocations of this method. This method does + not guarantee the order of the elements in its result. + """ + return list(self._jtracker.getJobIdsForGroup(jobGroup)) + + def getActiveStageIds(self): + """ + Returns an array containing the ids of all active stages. + """ + return sorted(list(self._jtracker.getActiveStageIds())) + + def getActiveJobsIds(self): + """ + Returns an array containing the ids of all active jobs. + """ + return sorted((list(self._jtracker.getActiveJobIds()))) + + def getJobInfo(self, jobId): + """ + Returns a :class:`SparkJobInfo` object, or None if the job info + could not be found or was garbage collected. + """ + job = self._jtracker.getJobInfo(jobId) + if job is not None: + return SparkJobInfo(jobId, job.stageIds(), str(job.status())) + + def getStageInfo(self, stageId): + """ + Returns a :class:`SparkStageInfo` object, or None if the stage + info could not be found or was garbage collected. + """ + stage = self._jtracker.getStageInfo(stageId) + if stage is not None: + # TODO: fetch them in batch for better performance + attrs = [getattr(stage, f)() for f in SparkStageInfo._fields[1:]] + return SparkStageInfo(stageId, *attrs) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index d48f3598e33b2..2c73083c9f9a8 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -21,7 +21,7 @@ from py4j.java_gateway import java_import, JavaObject from pyspark import RDD, SparkConf -from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer +from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer from pyspark.context import SparkContext from pyspark.storagelevel import StorageLevel from pyspark.streaming.dstream import DStream @@ -189,7 +189,16 @@ def awaitTermination(self, timeout=None): if timeout is None: self._jssc.awaitTermination() else: - self._jssc.awaitTermination(int(timeout * 1000)) + self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) + + def awaitTerminationOrTimeout(self, timeout): + """ + Wait for the execution to stop. Return `true` if it's stopped; or + throw the reported error during the execution; or `false` if the + waiting time elapsed before returning from the method. + @param timeout: time to wait in seconds + """ + self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) def stop(self, stopSparkContext=True, stopGraceFully=False): """ @@ -251,6 +260,20 @@ def textFileStream(self, directory): """ return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) + def binaryRecordsStream(self, directory, recordLength): + """ + Create an input stream that monitors a Hadoop-compatible file system + for new files and reads them as flat binary files with records of + fixed length. Files must be written to the monitored directory by "moving" + them from another location within the same file system. + File names starting with . are ignored. + + @param directory: Directory to load data from + @param recordLength: Length of each record in bytes + """ + return DStream(self._jssc.binaryRecordsStream(directory, recordLength), self, + NoOpSerializer()) + def _check_serializers(self, rdds): # make sure they have same serializer if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 2fe39392ff081..3fa42444239f7 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -578,7 +578,7 @@ def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) else: - g = a.cogroup(b, numPartitions) + g = a.cogroup(b.partitionBy(numPartitions), numPartitions) g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) return state.filter(lambda (k, v): v is not None) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 19ad71f99d4d5..f083ed149effb 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -16,7 +16,7 @@ # from py4j.java_collections import MapConverter -from py4j.java_gateway import java_import, Py4JError +from py4j.java_gateway import java_import, Py4JError, Py4JJavaError from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer @@ -50,8 +50,6 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A DStream object """ - java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils") - kafkaParams.update({ "zookeeper.connect": zkQuorum, "group.id": groupId, @@ -63,20 +61,34 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - def getClassByName(name): - return ssc._jvm.org.apache.spark.util.Utils.classForName(name) - try: - array = getClassByName("[B") - decoder = getClassByName("kafka.serializer.DefaultDecoder") - jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, array, array, decoder, decoder, - jparam, jtopics, jlevel) - except Py4JError, e: + # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027) + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel) + except Py4JJavaError, e: # TODO: use --jar once it also work on driver - if not e.message or 'call a package' in e.message: - print "No kafka package, please put the assembly jar into classpath:" - print " $ bin/spark-submit --driver-class-path external/kafka-assembly/target/" + \ - "scala-*/spark-streaming-kafka-assembly-*.jar" + if 'ClassNotFoundException' in str(e.java_exception): + print """ +________________________________________________________________________________________________ + + Spark Streaming's Kafka libraries not found in class path. Try one of the following. + + 1. Include the Kafka library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-kafka:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-kafka-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (ssc.sparkContext.version, ssc.sparkContext.version) raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) stream = DStream(jstream, ssc, ser) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a8d876d0fa3b3..608f8e26473a6 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -21,6 +21,7 @@ import operator import unittest import tempfile +import struct from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext @@ -455,6 +456,20 @@ def test_text_file_stream(self): self.wait_for(result, 2) self.assertEqual([range(10), range(10)], result) + def test_binary_records_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream = self.ssc.binaryRecordsStream(d, 10).map( + lambda v: struct.unpack("10b", str(v))) + result = self._collect(dstream, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "wb") as f: + f.write(bytearray(range(10))) + self.wait_for(result, 2) + self.assertEqual([range(10), range(10)], map(lambda v: list(v[0]), result)) + def test_union(self): input = [range(i + 1) for i in range(3)] dstream = self.ssc.queueStream(input) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c7d0622d65f25..c10f857c3d8ff 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -23,7 +23,6 @@ from fileinput import input from glob import glob import os -import pydoc import re import shutil import subprocess @@ -52,8 +51,6 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -546,6 +543,12 @@ def test_zip_with_different_serializers(self): # regression test for bug in _reserializer() self.assertEqual(cnt, t.zip(rdd).count()) + def test_zip_with_different_object_sizes(self): + # regress test for SPARK-5973 + a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i) + b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i) + self.assertEqual(10000, a.zip(b).count()) + def test_zip_with_different_number_of_items(self): a = self.sc.parallelize(range(5), 2) # different number of partitions @@ -730,7 +733,6 @@ def test_multiple_python_java_RDD_conversions(self): (u'1', {u'director': u'David Lean'}), (u'2', {u'director': u'Andrew Dominik'}) ] - from pyspark.rdd import RDD data_rdd = self.sc.parallelize(data) data_java_rdd = data_rdd._to_java_object_rdd() data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) @@ -743,6 +745,59 @@ def test_multiple_python_java_RDD_conversions(self): converted_rdd = RDD(data_python_rdd, self.sc) self.assertEqual(2, converted_rdd.count()) + def test_narrow_dependency_in_join(self): + rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) + parted = rdd.partitionBy(2) + self.assertEqual(2, parted.union(parted).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) + + self.sc.setJobGroup("test1", "test", True) + tracker = self.sc.statusTracker() + + d = sorted(parted.join(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test1")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test2", "test", True) + d = sorted(parted.join(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test2")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test3", "test", True) + d = sorted(parted.cogroup(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], map(list, d[0][1])) + jobId = tracker.getJobIdsForGroup("test3")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test4", "test", True) + d = sorted(parted.cogroup(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], map(list, d[0][1])) + jobId = tracker.getJobIdsForGroup("test4")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + # Regression test for SPARK-6294 + def test_take_on_jrdd(self): + rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x)) + rdd._jrdd.first() + + def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): + # Regression test for SPARK-5969 + seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence + rdd = self.sc.parallelize(seq) + for ascending in [True, False]: + sort = rdd.sortByKey(ascending=ascending, numPartitions=5) + self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) + sizes = sort.glom().map(len).collect() + for size in sizes: + self.assertGreater(size, 0) + class ProfilerTests(PySparkTestCase): @@ -795,264 +850,6 @@ def heavy_foo(x): rdd.foreach(heavy_foo) -class ExamplePointUDT(UserDefinedType): - """ - User-defined type (UDT) for ExamplePoint. - """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return 'pyspark.tests' - - @classmethod - def scalaUDT(cls): - return 'org.apache.spark.sql.test.ExamplePointUDT' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return ExamplePoint(datum[0], datum[1]) - - -class ExamplePoint: - """ - An example class to demonstrate UDT in Scala, Java, and Python. - """ - - __UDT__ = ExamplePointUDT() - - def __init__(self, x, y): - self.x = x - self.y = y - - def __repr__(self): - return "ExamplePoint(%s,%s)" % (self.x, self.y) - - def __str__(self): - return "(%s,%s)" % (self.x, self.y) - - def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ - other.x == self.x and other.y == self.y - - -class SQLTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) - - def setUp(self): - self.sqlCtx = SQLContext(self.sc) - self.testData = [Row(key=i, value=str(i)) for i in range(100)] - rdd = self.sc.parallelize(self.testData) - self.df = self.sqlCtx.inferSchema(rdd) - - def test_udf(self): - self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - def test_udf2(self): - self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") - [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(4, res[0]) - - def test_udf_with_array_type(self): - d = [Row(l=range(3), d={"key": range(5)})] - rdd = self.sc.parallelize(d) - self.sqlCtx.inferSchema(rdd).registerTempTable("test") - self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) - self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(range(3), l1) - self.assertEqual(1, l2) - - def test_broadcast_in_udf(self): - bar = {"a": "aa", "b": "bb", "c": "abc"} - foo = self.sc.broadcast(bar) - self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() - self.assertEqual("abc", res[0]) - [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() - self.assertEqual("", res[0]) - - def test_basic_functions(self): - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) - df.count() - df.collect() - df.schema() - - # cache and checkpoint - self.assertFalse(df.is_cached) - df.persist() - df.unpersist() - df.cache() - self.assertTrue(df.is_cached) - self.assertEqual(2, df.count()) - - df.registerTempTable("temp") - df = self.sqlCtx.sql("select foo from temp") - df.count() - df.collect() - - def test_apply_schema_to_row(self): - df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) - self.assertEqual(df.collect(), df2.collect()) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.sqlCtx.applySchema(rdd, df.schema()) - self.assertEqual(10, df3.count()) - - def test_serialize_nested_array_and_map(self): - d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - row = df.head() - self.assertEqual(1, len(row.l)) - self.assertEqual(1, row.l[0].a) - self.assertEqual("2", row.d["key"].d) - - l = df.map(lambda x: x.l).first() - self.assertEqual(1, len(l)) - self.assertEqual('s', l[0].b) - - d = df.map(lambda x: x.d).first() - self.assertEqual(1, len(d)) - self.assertEqual(1.0, d["key"].c) - - row = df.map(lambda x: x.d["key"]).first() - self.assertEqual(1.0, row.c) - self.assertEqual("2", row.d) - - def test_infer_schema(self): - d = [Row(l=[], d={}), - Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], df.map(lambda r: r.l).first()) - self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) - df.registerTempTable("test") - result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - df2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(df.schema(), df2.schema()) - self.assertEqual({}, df2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) - df2.registerTempTable("test2") - result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - def test_struct_in_map(self): - d = [Row(m={Row(i=1): Row(s="")})] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - k, v = df.head().m.items()[0] - self.assertEqual(1, k.i) - self.assertEqual("", v.s) - - def test_convert_row_to_dict(self): - row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) - self.assertEqual(1, row.asDict()['l'][0].a) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) - df.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").head() - self.assertEqual(1, row.asDict()["l"][0].a) - self.assertEqual(1.0, row.asDict()['d']['key'].c) - - def test_infer_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) - schema = df.schema() - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), ExamplePointUDT) - df.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - def test_apply_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - df = self.sqlCtx.applySchema(rdd, schema) - point = df.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_parquet_with_udt(self): - from pyspark.tests import ExamplePoint - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df0 = self.sqlCtx.inferSchema(rdd) - output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.saveAsParquetFile(output_dir) - df1 = self.sqlCtx.parquetFile(output_dir) - point = df1.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_column_operators(self): - from pyspark.sql import Column, LongType - ci = self.df.key - cs = self.df.value - c = ci == cs - self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) - self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] - self.assertTrue(all(isinstance(c, Column) for c in cb)) - cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci) - self.assertTrue(all(isinstance(c, Column) for c in cbit)) - css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') - self.assertTrue(all(isinstance(c, Column) for c in css)) - self.assertTrue(isinstance(ci.cast(LongType()), Column)) - - def test_column_select(self): - df = self.df - self.assertEqual(self.testData, df.select("*").collect()) - self.assertEqual(self.testData, df.select(df.key, df.value).collect()) - self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) - - def test_aggregator(self): - df = self.df - g = df.groupBy() - self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) - self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) - - from pyspark.sql import Aggregator as Agg - self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first())) - self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0]) - self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0]) - - def test_help_command(self): - # Regression test for SPARK-5464 - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) - # render_doc() reproduces the help() exception without printing output - pydoc.render_doc(df) - pydoc.render_doc(df.foo) - pydoc.render_doc(df.take(1)) - - class InputFormatTests(ReusedPySparkTestCase): @classmethod @@ -1665,31 +1462,59 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.programDir) - def createTempFile(self, name, content): + def createTempFile(self, name, content, dir=None): """ Create a temp file with the given name and content and return its path. Strips leading spaces from content up to the first '|' in each line. """ pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) - path = os.path.join(self.programDir, name) + if dir is None: + path = os.path.join(self.programDir, name) + else: + os.makedirs(os.path.join(self.programDir, dir)) + path = os.path.join(self.programDir, dir, name) with open(path, "w") as f: f.write(content) return path - def createFileInZip(self, name, content): + def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): """ Create a zip archive containing a file with the given content and return its path. Strips leading spaces from content up to the first '|' in each line. """ pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) - path = os.path.join(self.programDir, name + ".zip") + if dir is None: + path = os.path.join(self.programDir, name + ext) + else: + path = os.path.join(self.programDir, dir, zip_name + ext) zip = zipfile.ZipFile(path, 'w') zip.writestr(name, content) zip.close() return path + def create_spark_package(self, artifact_name): + group_id, artifact_id, version = artifact_name.split(":") + self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" + | + | + | 4.0.0 + | %s + | %s + | %s + | + """ % (group_id, artifact_id, version)).lstrip(), + os.path.join(group_id, artifact_id, version)) + self.createFileInZip("%s.py" % artifact_id, """ + |def myfunc(x): + | return x + 1 + """, ".jar", os.path.join(group_id, artifact_id, version), + "%s-%s" % (artifact_id, version)) + def test_single_script(self): """Submit and test a single script file""" script = self.createTempFile("test.py", """ @@ -1758,6 +1583,39 @@ def test_module_dependency_on_cluster(self): self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out) + def test_package_dependency(self): + """Submit and test a script with a dependency on a Spark Package""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out) + + def test_package_dependency_on_cluster(self): + """Submit and test a script with a dependency on a Spark Package on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", + "local-cluster[1,1,512]", script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out) + def test_single_script_on_cluster(self): """Submit and test a single script on a cluster""" script = self.createTempFile("test.py", """ @@ -1811,6 +1669,37 @@ def test_with_stop(self): sc.stop() self.assertEqual(SparkContext._active_spark_context, None) + def test_progress_api(self): + with SparkContext() as sc: + sc.setJobGroup('test_progress_api', '', True) + + rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) + t = threading.Thread(target=rdd.collect) + t.daemon = True + t.start() + # wait for scheduler to start + time.sleep(1) + + tracker = sc.statusTracker() + jobIds = tracker.getJobIdsForGroup('test_progress_api') + self.assertEqual(1, len(jobIds)) + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual(1, len(job.stageIds)) + stage = tracker.getStageInfo(job.stageIds[0]) + self.assertEqual(rdd.getNumPartitions(), stage.numTasks) + + sc.cancelAllJobs() + t.join() + # wait for event listener to update the status + time.sleep(1) + + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual('FAILED', job.status) + self.assertEqual([], tracker.getActiveJobsIds()) + self.assertEqual([], tracker.getActiveStageIds()) + + sc.stop() + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/run-tests b/python/run-tests index e91f1a875d356..a2c2f37a54eda 100755 --- a/python/run-tests +++ b/python/run-tests @@ -35,7 +35,7 @@ rm -rf metastore warehouse function run_test() { echo "Running test: $1" | tee -a $LOG_FILE - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1 FAILED=$((PIPESTATUS[0]||$FAILED)) @@ -64,7 +64,11 @@ function run_core_tests() { function run_sql_tests() { echo "Run sql tests ..." - run_test "pyspark/sql.py" + run_test "pyspark/sql/types.py" + run_test "pyspark/sql/context.py" + run_test "pyspark/sql/dataframe.py" + run_test "pyspark/sql/functions.py" + run_test "pyspark/sql/tests.py" } function run_mllib_tests() { diff --git a/repl/pom.xml b/repl/pom.xml index bd39b90fd8714..18c5664ba7a2f 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../pom.xml @@ -66,7 +66,6 @@ org.apache.spark spark-sql_${scala.binary.version} ${project.version} - test org.scala-lang @@ -87,6 +86,11 @@ scalacheck_${scala.binary.version} test + + org.mockito + mockito-all + test + diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 72c1a989999b4..8dc0e0c965923 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -45,6 +45,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse} import org.apache.spark.Logging import org.apache.spark.SparkConf import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils /** The Scala interactive shell. It provides a read-eval-print loop @@ -130,6 +131,7 @@ class SparkILoop( // NOTE: Must be public for visibility @DeveloperApi var sparkContext: SparkContext = _ + var sqlContext: SQLContext = _ override def echoCommandMessage(msg: String) { intp.reporter printMessage msg @@ -1016,6 +1018,23 @@ class SparkILoop( sparkContext } + @DeveloperApi + def createSQLContext(): SQLContext = { + val name = "org.apache.spark.sql.hive.HiveContext" + val loader = Utils.getContextOrSparkClassLoader + try { + sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) + .newInstance(sparkContext).asInstanceOf[SQLContext] + logInfo("Created sql context (with Hive support)..") + } + catch { + case cnf: java.lang.ClassNotFoundException => + sqlContext = new SQLContext(sparkContext) + logInfo("Created sql context..") + } + sqlContext + } + private def getMaster(): String = { val master = this.master match { case Some(m) => m @@ -1045,15 +1064,16 @@ class SparkILoop( private def main(settings: Settings): Unit = process(settings) } -object SparkILoop { +object SparkILoop extends Logging { implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp private def echo(msg: String) = Console println msg def getAddedJars: Array[String] = { val envJars = sys.env.get("ADD_JARS") - val propJars = sys.props.get("spark.jars").flatMap { p => - if (p == "") None else Some(p) + if (envJars.isDefined) { + logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") } + val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } val jars = propJars.orElse(envJars).getOrElse("") Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 99bd777c04fdb..05faef8786d2c 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -127,7 +127,17 @@ private[repl] trait SparkILoopInit { _sc } """) + command(""" + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } + """) command("import org.apache.spark.SparkContext._") + command("import sqlContext.implicits._") + command("import sqlContext.sql") + command("import org.apache.spark.sql.functions._") } } diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index e594ad868ea1c..934daaeaafca1 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -121,9 +121,9 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ |var v = 7 - |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -137,7 +137,7 @@ class ReplSuite extends FunSuite { |class C { |def foo = 5 |} - |sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => (new C).foo).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -148,7 +148,7 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ |def double(x: Int) = x + x - |sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => double(x)).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -160,9 +160,9 @@ class ReplSuite extends FunSuite { """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -178,9 +178,9 @@ class ReplSuite extends FunSuite { """ |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -216,14 +216,14 @@ class ReplSuite extends FunSuite { """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -255,14 +255,14 @@ class ReplSuite extends FunSuite { assertDoesNotContain("Exception", output) } - test("SPARK-2576 importing SQLContext.createDataFrame.") { + test("SPARK-2576 importing SQLContext.implicits._") { // We need to use local-cluster to test this case. val output = runInterpreter("local-cluster[1,1,512]", """ |val sqlContext = new org.apache.spark.sql.SQLContext(sc) - |import sqlContext.createDataFrame + |import sqlContext.implicits._ |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDataFrame.collect() + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -275,26 +275,26 @@ class ReplSuite extends FunSuite { |val t = new TestClass |import t.testMethod |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } - if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { + if (System.getenv("MESOS_NATIVE_JAVA_LIBRARY") != null) { test("running on Mesos") { val output = runInterpreter("localquiet", """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -309,10 +309,22 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local[2]", """ |case class Foo(i: Int) - |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect + |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) assertContains("ret: Array[Foo] = Array(Foo(1),", output) } + + test("collecting objects of class defined in repl - shuffling") { + val output = runInterpreter("local-cluster[1,1,512]", + """ + |case class Foo(i: Int) + |val list = List((1, Foo(1)), (1, Foo(2))) + |val ret = sc.parallelize(list).groupByKey().collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output) + } } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 69e44d4f916e1..2210fbaafeadb 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -19,6 +19,7 @@ package org.apache.spark.repl import org.apache.spark.util.Utils import org.apache.spark._ +import org.apache.spark.sql.SQLContext import scala.tools.nsc.Settings import scala.tools.nsc.interpreter.SparkILoop @@ -34,6 +35,7 @@ object Main extends Logging { "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true) val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ + var sqlContext: SQLContext = _ var interp = new SparkILoop // this is a public var because tests reset it. def main(args: Array[String]) { @@ -49,6 +51,9 @@ object Main extends Logging { def getAddedJars: Array[String] = { val envJars = sys.env.get("ADD_JARS") + if (envJars.isDefined) { + logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") + } val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } val jars = propJars.orElse(envJars).getOrElse("") Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) @@ -74,6 +79,22 @@ object Main extends Logging { sparkContext } + def createSQLContext(): SQLContext = { + val name = "org.apache.spark.sql.hive.HiveContext" + val loader = Utils.getContextOrSparkClassLoader + try { + sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) + .newInstance(sparkContext).asInstanceOf[SQLContext] + logInfo("Created sql context (with Hive support)..") + } + catch { + case cnf: java.lang.ClassNotFoundException => + sqlContext = new SQLContext(sparkContext) + logInfo("Created sql context..") + } + sqlContext + } + private def getMaster: String = { val master = { val envMaster = sys.env.get("MASTER") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 250727305970d..7a5e94da5cbf3 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -66,8 +66,18 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) println("Spark context available as sc.") _sc } - """) + """) + command( """ + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } + """) command("import org.apache.spark.SparkContext._") + command("import sqlContext.implicits._") + command("import sqlContext.sql") + command("import org.apache.spark.sql.functions._") } } diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index f966f25c5a14c..fbef5b25ba688 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -128,9 +128,9 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ |var v = 7 - |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -144,7 +144,7 @@ class ReplSuite extends FunSuite { |class C { |def foo = 5 |} - |sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => (new C).foo).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -155,7 +155,7 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ |def double(x: Int) = x + x - |sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => double(x)).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -167,9 +167,9 @@ class ReplSuite extends FunSuite { """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -185,9 +185,9 @@ class ReplSuite extends FunSuite { """ |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -224,14 +224,14 @@ class ReplSuite extends FunSuite { """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -263,14 +263,14 @@ class ReplSuite extends FunSuite { assertDoesNotContain("Exception", output) } - test("SPARK-2576 importing SQLContext.createSchemaRDD.") { + test("SPARK-2576 importing SQLContext.createDataFrame.") { // We need to use local-cluster to test this case. val output = runInterpreter("local-cluster[1,1,512]", """ |val sqlContext = new org.apache.spark.sql.SQLContext(sc) - |import sqlContext.createSchemaRDD + |import sqlContext.implicits._ |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toSchemaRDD.collect + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -283,26 +283,26 @@ class ReplSuite extends FunSuite { |val t = new TestClass |import t.testMethod |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } - if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { + if (System.getenv("MESOS_NATIVE_JAVA_LIBRARY") != null) { test("running on Mesos") { val output = runInterpreter("localquiet", """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -317,10 +317,22 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local[2]", """ |case class Foo(i: Int) - |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect + |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) assertContains("ret: Array[Foo] = Array(Foo(1),", output) } + + test("collecting objects of class defined in repl - shuffling") { + val output = runInterpreter("local-cluster[1,1,512]", + """ + |case class Foo(i: Int) + |val list = List((1, Foo(1)), (1, Foo(2))) + |val ret = sc.parallelize(list).groupByKey().collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output) + } } diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 9805609120005..004941d5f50ae 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,9 +17,10 @@ package org.apache.spark.repl -import java.io.{ByteArrayOutputStream, InputStream, FileNotFoundException} -import java.net.{URI, URL, URLEncoder} -import java.util.concurrent.{Executors, ExecutorService} +import java.io.{IOException, ByteArrayOutputStream, InputStream} +import java.net.{HttpURLConnection, URI, URL, URLEncoder} + +import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -43,6 +44,9 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader val parentLoader = new ParentClassLoader(parent) + // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes + private[repl] var httpUrlConnectionTimeoutMillis: Int = -1 + // Hadoop FileSystem object for our URI, if it isn't using HTTP var fileSystem: FileSystem = { if (Set("http", "https", "ftp").contains(uri.getScheme)) { @@ -71,30 +75,66 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { + val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { + val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) + val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) + newuri.toURL + } else { + new URL(classUri + "/" + urlEncode(pathInDirectory)) + } + val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(), + SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection] + // Set the connection timeouts (for testing purposes) + if (httpUrlConnectionTimeoutMillis != -1) { + connection.setConnectTimeout(httpUrlConnectionTimeoutMillis) + connection.setReadTimeout(httpUrlConnectionTimeoutMillis) + } + connection.connect() + try { + if (connection.getResponseCode != 200) { + // Close the error stream so that the connection is eligible for re-use + try { + connection.getErrorStream.close() + } catch { + case ioe: IOException => + logError("Exception while closing error stream", ioe) + } + throw new ClassNotFoundException(s"Class file not found at URL $url") + } else { + connection.getInputStream + } + } catch { + case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] => + connection.disconnect() + throw e + } + } + + private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = { + val path = new Path(directory, pathInDirectory) + if (fileSystem.exists(path)) { + fileSystem.open(path) + } else { + throw new ClassNotFoundException(s"Class file not found at path $path") + } + } + def findClassLocally(name: String): Option[Class[_]] = { + val pathInDirectory = name.replace('.', '/') + ".class" + var inputStream: InputStream = null try { - val pathInDirectory = name.replace('.', '/') + ".class" - val inputStream = { + inputStream = { if (fileSystem != null) { - fileSystem.open(new Path(directory, pathInDirectory)) + getClassFileInputStreamFromFileSystem(pathInDirectory) } else { - val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { - val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) - val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) - newuri.toURL - } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)) - } - - Utils.setupSecureURLConnection(url.openConnection(), SparkEnv.get.securityManager) - .getInputStream + getClassFileInputStreamFromHttpServer(pathInDirectory) } } val bytes = readAndTransformClass(name, inputStream) - inputStream.close() Some(defineClass(name, bytes, 0, bytes.length)) } catch { - case e: FileNotFoundException => + case e: ClassNotFoundException => // We did not find the class logDebug(s"Did not load class $name from REPL class server at $uri", e) None @@ -102,6 +142,15 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader // Something bad happened while checking if the class exists logError(s"Failed to check existence of class $name on REPL class server at $uri", e) None + } finally { + if (inputStream != null) { + try { + inputStream.close() + } catch { + case e: Exception => + logError("Exception while closing inputStream", e) + } + } } } diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 6a79e76a34db8..c709cde740748 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -20,13 +20,25 @@ package org.apache.spark.repl import java.io.File import java.net.{URL, URLClassLoader} +import scala.concurrent.duration._ +import scala.language.implicitConversions +import scala.language.postfixOps + import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite +import org.scalatest.concurrent.Interruptor +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.mock.MockitoSugar +import org.mockito.Mockito._ -import org.apache.spark.{SparkConf, TestUtils} +import org.apache.spark._ import org.apache.spark.util.Utils -class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { +class ExecutorClassLoaderSuite + extends FunSuite + with BeforeAndAfterAll + with MockitoSugar + with Logging { val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") @@ -34,6 +46,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { var tempDir2: File = _ var url1: String = _ var urls2: Array[URL] = _ + var classServer: HttpServer = _ override def beforeAll() { super.beforeAll() @@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { override def afterAll() { super.afterAll() + if (classServer != null) { + classServer.stop() + } Utils.deleteRecursively(tempDir1) Utils.deleteRecursively(tempDir2) + SparkEnv.set(null) } test("child first") { @@ -83,4 +100,53 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { } } + test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") { + // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class + // from the driver's class server would leak a HTTP connection, causing the class server's + // thread / connection pool to be exhausted. + val conf = new SparkConf() + val securityManager = new SecurityManager(conf) + classServer = new HttpServer(conf, tempDir1, securityManager) + classServer.start() + // ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this + val mockEnv = mock[SparkEnv] + when(mockEnv.securityManager).thenReturn(securityManager) + SparkEnv.set(mockEnv) + // Create an ExecutorClassLoader that's configured to load classes from the HTTP server + val parentLoader = new URLClassLoader(Array.empty, null) + val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false) + classLoader.httpUrlConnectionTimeoutMillis = 500 + // Check that this class loader can actually load classes that exist + val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + // Try to perform a full GC now, since GC during the test might mask resource leaks + System.gc() + // When the original bug occurs, the test thread becomes blocked in a classloading call + // and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to + // shut down the HTTP server when the test times out + val interruptor: Interruptor = new Interruptor { + override def apply(thread: Thread): Unit = { + classServer.stop() + classServer = null + thread.interrupt() + } + } + def tryAndFailToLoadABunchOfClasses(): Unit = { + // The number of trials here should be much larger than Jetty's thread / connection limit + // in order to expose thread or connection leaks + for (i <- 1 to 1000) { + if (Thread.currentThread().isInterrupted) { + throw new InterruptedException() + } + // Incorporate the iteration number into the class name in order to avoid any response + // caching that might be added in the future + intercept[ClassNotFoundException] { + classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance() + } + } + } + failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor) + } + } diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 89608bc41b71d..5e812a1d91c6b 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -129,8 +129,9 @@ case $option in mkdir -p "$SPARK_PID_DIR" if [ -f $pid ]; then - if kill -0 `cat $pid` > /dev/null 2>&1; then - echo $command running as process `cat $pid`. Stop it first. + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o args=) =~ $command ]]; then + echo "$command running as process $TARGET_ID. Stop it first." exit 1 fi fi @@ -141,7 +142,7 @@ case $option in fi spark_rotate_log "$log" - echo starting $command, logging to $log + echo "starting $command, logging to $log" if [ $option == spark-submit ]; then source "$SPARK_HOME"/bin/utils.sh gatherSparkSubmitOpts "$@" @@ -154,7 +155,7 @@ case $option in echo $newpid > $pid sleep 2 # Check if the process has died; in that case we'll tail the log so the user can see - if ! kill -0 $newpid >/dev/null 2>&1; then + if [[ ! $(ps -p "$newpid" -o args=) =~ $command ]]; then echo "failed to launch $command:" tail -2 "$log" | sed 's/^/ /' echo "full log in $log" @@ -164,14 +165,15 @@ case $option in (stop) if [ -f $pid ]; then - if kill -0 `cat $pid` > /dev/null 2>&1; then - echo stopping $command - kill `cat $pid` + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then + echo "stopping $command" + kill "$TARGET_ID" && rm -f "$pid" else - echo no $command to stop + echo "no $command to stop" fi else - echo no $command to stop + echo "no $command to stop" fi ;; diff --git a/sql/README.md b/sql/README.md index 61a20916a92aa..0125de61f875a 100644 --- a/sql/README.md +++ b/sql/README.md @@ -22,7 +22,8 @@ export HADOOP_HOME="/hadoop-1.0.4" Using the console ================= -An interactive scala console can be invoked by running `build/sbt hive/console`. From here you can execute queries and inspect the various stages of query optimization. +An interactive scala console can be invoked by running `build/sbt hive/console`. +From here you can execute queries with HiveQl and manipulate DataFrame by using DSL. ```scala catalyst$ build/sbt hive/console @@ -35,46 +36,27 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.execution import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.TestHive._ +import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.types._ -Welcome to Scala version 2.10.4 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). +import org.apache.spark.sql.parquet.ParquetTestData Type in expressions to have them evaluated. Type :help for more information. scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.DataFrame = -== Query Plan == -== Physical Plan == -HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None +query: org.apache.spark.sql.DataFrame = org.apache.spark.sql.DataFrame@74448eed ``` -Query results are RDDs and can be operated as such. +Query results are `DataFrames` and can be operated as such. ``` scala> query.collect() res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... ``` -You can also build further queries on top of these RDDs using the query DSL. +You can also build further queries on top of these `DataFrames` using the query DSL. ``` -scala> query.where('key === 100).collect() -res3: Array[org.apache.spark.sql.Row] = Array([100,val_100], [100,val_100]) -``` - -From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala) objects. -```scala -scala> query.queryExecution.analyzed -res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = -Project [key#10,value#11] - Project [key#10,value#11] - MetastoreRelation default, src, None - - -scala> query.queryExecution.analyzed transform { - | case Project(projectList, child) if projectList == child.output => child - | } -res5: res17: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = -Project [key#10,value#11] - MetastoreRelation default, src, None +scala> query.where(query("key") > 30).select(avg(query("key"))).collect() +res3: Array[org.apache.spark.sql.Row] = Array([274.79025423728814]) ``` diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index a1947fb022e54..9804817e3d1b1 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala new file mode 100644 index 0000000000000..34fedead44db3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Thrown when a query fails to analyze, usually because the query itself is invalid. + */ +@DeveloperApi +class AnalysisException protected[sql] ( + val message: String, + val line: Option[Int] = None, + val startPosition: Option[Int] = None) + extends Exception with Serializable { + + def withPosition(line: Option[Int], startPosition: Option[Int]) = { + val newException = new AnalysisException(message, line, startPosition) + newException.setStackTrace(getStackTrace) + newException + } + + override def getMessage: String = { + val lineAnnotation = line.map(l => s" line $l").getOrElse("") + val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("") + s"$message;$lineAnnotation$positionAnnotation" + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 41bb4f012f2e1..d794f034f5578 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.util.hashing.MurmurHash3 import org.apache.spark.sql.catalyst.expressions.GenericRow - +import org.apache.spark.sql.types.{StructType, DateUtils} object Row { /** @@ -122,6 +122,11 @@ trait Row extends Serializable { /** Number of elements in the Row. */ def length: Int + /** + * Schema for the row. + */ + def schema: StructType = null + /** * Returns the value at position i. If the value is null, null is returned. The following * is a mapping between Spark SQL types and return types: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 366be00473d1c..3823584287741 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -26,7 +26,7 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ private[sql] object KeywordNormalizer { - def apply(str: String) = str.toLowerCase() + def apply(str: String): String = str.toLowerCase() } private[sql] abstract class AbstractSparkSQLParser @@ -42,7 +42,7 @@ private[sql] abstract class AbstractSparkSQLParser } protected case class Keyword(str: String) { - def normalize = KeywordNormalizer(str) + def normalize: String = KeywordNormalizer(str) def parser: Parser[String] = normalize } @@ -81,7 +81,7 @@ private[sql] abstract class AbstractSparkSQLParser class SqlLexical extends StdLexical { case class FloatLit(chars: String) extends Token { - override def toString = chars + override def toString: String = chars } /* This is a work around to support the lazy setting */ @@ -120,7 +120,7 @@ class SqlLexical extends StdLexical { | failure("illegal character") ) - override def identChar = letter | elem('_') + override def identChar: Parser[Elem] = letter | elem('_') override def whitespace: Parser[Any] = ( whitespaceChar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e0db587efb08d..8bfd0471d9c7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import org.apache.spark.util.Utils -import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ - /** * A default version of ScalaReflection that uses the runtime universe. */ @@ -72,6 +71,12 @@ trait ScalaReflection { }.toArray) case (d: BigDecimal, _) => Decimal(d) case (d: java.math.BigDecimal, _) => Decimal(d) + case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d) + case (r: Row, structType: StructType) => + new GenericRow( + r.toSeq.zip(structType.fields).map { case (elem, field) => + convertToCatalyst(elem, field.dataType) + }.toArray) case (other, _) => other } @@ -85,14 +90,15 @@ trait ScalaReflection { } case (r: Row, s: StructType) => convertRowToScala(r, s) case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal + case (i: Int, DateType) => DateUtils.toJavaDate(i) case (other, _) => other } def convertRowToScala(r: Row, schema: StructType): Row = { // TODO: This is very slow!!! - new GenericRow( + new GenericRowWithSchema( r.toSeq.zip(schema.fields.map(_.dataType)) - .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray) + .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray, schema) } /** Returns a Sequence of attributes for the given case class type. */ @@ -102,10 +108,11 @@ trait ScalaReflection { } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) + def schemaFor[T: TypeTag]: Schema = + ScalaReflectionLock.synchronized { schemaFor(typeOf[T]) } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): Schema = { + def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { val className: String = tpe.erasure.typeSymbol.asClass.fullName tpe match { case t if Utils.classIsLoadable(className) && @@ -120,6 +127,21 @@ trait ScalaReflection { case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) + // Need to decide if we actually need a special type here. + case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) + case t if t <:< typeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema(MapType(schemaFor(keyType).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< typeOf[Product] => val formalTypeArgs = t.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = t @@ -142,24 +164,9 @@ trait ScalaReflection { schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) StructField(p.name.toString, dataType, nullable) }), nullable = true) - // Need to decide if we actually need a special type here. - case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< typeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - Schema(MapType(schemaFor(keyType).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< typeOf[String] => Schema(StringType, nullable = true) case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) + case t if t <:< typeOf[java.sql.Date] => Schema(DateType, nullable = true) case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) case t if t <:< typeOf[java.math.BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) @@ -177,6 +184,8 @@ trait ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case other => + throw new UnsupportedOperationException(s"Schema for type $other is not supported") } } @@ -191,7 +200,7 @@ trait ScalaReflection { case obj: LongType.JvmType => LongType case obj: FloatType.JvmType => FloatType case obj: DoubleType.JvmType => DoubleType - case obj: DateType.JvmType => DateType + case obj: java.sql.Date => DateType case obj: java.math.BigDecimal => DecimalType.Unlimited case obj: Decimal => DecimalType.Unlimited case obj: TimestampType.JvmType => TimestampType @@ -210,7 +219,7 @@ trait ScalaReflection { */ def asRelation: LocalRelation = { val output = attributesFor[A] - LocalRelation(output, data) + LocalRelation.fromProduct(output, data) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala old mode 100755 new mode 100644 index 25e639d390da0..b176f7e729a42 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -35,7 +35,17 @@ import org.apache.spark.sql.types._ * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends AbstractSparkSQLParser { +class SqlParser extends AbstractSparkSQLParser with DataTypeParser { + + def parseExpression(input: String): Expression = { + // Initialize the Keywords. + lexical.initialize(reservedWords) + phrase(projection)(new lexical.Scanner(input)) match { + case Success(plan, _) => plan + case failureOrError => sys.error(failureOrError.toString) + } + } + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object protected val ABS = Keyword("ABS") @@ -47,15 +57,12 @@ class SqlParser extends AbstractSparkSQLParser { protected val AVG = Keyword("AVG") protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") - protected val CACHE = Keyword("CACHE") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") protected val COALESCE = Keyword("COALESCE") protected val COUNT = Keyword("COUNT") - protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") - protected val DOUBLE = Keyword("DOUBLE") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") @@ -94,13 +101,11 @@ class SqlParser extends AbstractSparkSQLParser { protected val SELECT = Keyword("SELECT") protected val SEMI = Keyword("SEMI") protected val SQRT = Keyword("SQRT") - protected val STRING = Keyword("STRING") protected val SUBSTR = Keyword("SUBSTR") protected val SUBSTRING = Keyword("SUBSTRING") protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") protected val THEN = Keyword("THEN") - protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") protected val UNION = Keyword("UNION") protected val UPPER = Keyword("UPPER") @@ -134,7 +139,7 @@ class SqlParser extends AbstractSparkSQLParser { sortType.? ~ (LIMIT ~> expression).? ^^ { case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(NoRelation) + val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g .map(Aggregate(_, assignAliases(p), withFilter)) @@ -304,7 +309,9 @@ class SqlParser extends AbstractSparkSQLParser { ) protected lazy val cast: Parser[Expression] = - CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) } + CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { + case exp ~ t => Cast(exp, t) + } protected lazy val literal: Parser[Literal] = ( numericLiteral @@ -362,7 +369,7 @@ class SqlParser extends AbstractSparkSQLParser { | expression ~ ("[" ~> expression <~ "]") ^^ { case base ~ ordinal => GetItem(base, ordinal) } | (expression <~ ".") ~ ident ^^ - { case base ~ fieldName => GetField(base, fieldName) } + { case base ~ fieldName => UnresolvedGetField(base, fieldName) } | cast | "(" ~> expression <~ ")" | function @@ -374,19 +381,6 @@ class SqlParser extends AbstractSparkSQLParser { protected lazy val dotExpressionHeader: Parser[Expression] = (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { - case i1 ~ i2 ~ rest => UnresolvedAttribute(i1 + "." + i2 + rest.mkString(".", ".", "")) - } - - protected lazy val dataType: Parser[DataType] = - ( STRING ^^^ StringType - | TIMESTAMP ^^^ TimestampType - | DOUBLE ^^^ DoubleType - | fixedDecimalType - | DECIMAL ^^^ DecimalType.Unlimited - ) - - protected lazy val fixedDecimalType: Parser[DataType] = - (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString(".")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cefd70acf3931..a45a48a424e9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.util.collection.OpenHashSet +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types._ /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -37,11 +37,12 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and * a [[FunctionRegistry]]. */ -class Analyzer(catalog: Catalog, - registry: FunctionRegistry, - caseSensitive: Boolean, - maxIterations: Int = 100) - extends RuleExecutor[LogicalPlan] with HiveTypeCoercion { +class Analyzer( + catalog: Catalog, + registry: FunctionRegistry, + caseSensitive: Boolean, + maxIterations: Int = 100) + extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis { val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution @@ -50,51 +51,23 @@ class Analyzer(catalog: Catalog, /** * Override to provide additional rules for the "Resolution" batch. */ - val extendedRules: Seq[Rule[LogicalPlan]] = Nil + val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil lazy val batches: Seq[Batch] = Seq( - Batch("MultiInstanceRelations", Once, - NewRelationInstances), Batch("Resolution", fixedPoint, - ResolveReferences :: ResolveRelations :: + ResolveReferences :: ResolveGroupingAnalytics :: ResolveSortReferences :: - NewRelationInstances :: ImplicitGenerate :: ResolveFunctions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: TrimGroupingAliases :: typeCoercionRules ++ - extendedRules : _*), - Batch("Check Analysis", Once, - CheckResolution, - CheckAggregation), - Batch("AnalysisOperators", fixedPoint, - EliminateAnalysisOperators) + extendedResolutionRules : _*) ) - /** - * Makes sure all attributes and logical plans have been resolved. - */ - object CheckResolution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - plan.transform { - case p if p.expressions.exists(!_.resolved) => - throw new TreeNodeException(p, - s"Unresolved attributes: ${p.expressions.filterNot(_.resolved).mkString(",")}") - case p if !p.resolved && p.childrenResolved => - throw new TreeNodeException(p, "Unresolved plan found") - } match { - // As a backstop, use the root node to check that the entire plan tree is resolved. - case p if !p.resolved => - throw new TreeNodeException(p, "Unresolved plan in tree") - case p => p - } - } - } - /** * Removes no-op Alias expressions from the plan. */ @@ -193,46 +166,24 @@ class Analyzer(catalog: Catalog, } /** - * Checks for non-aggregated attributes with aggregation + * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ - object CheckAggregation extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - plan.transform { - case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) => - def isValidAggregateExpression(expr: Expression): Boolean = expr match { - case _: AggregateExpression => true - case e: Attribute => groupingExprs.contains(e) - case e if groupingExprs.contains(e) => true - case e if e.references.isEmpty => true - case e => e.children.forall(isValidAggregateExpression) - } - - aggregateExprs.find { e => - !isValidAggregateExpression(e.transform { - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - case Alias(g: GetField, _) => g - }) - }.foreach { e => - throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") - } - - aggregatePlan + object ResolveRelations extends Rule[LogicalPlan] { + def getTable(u: UnresolvedRelation): LogicalPlan = { + try { + catalog.lookupRelation(u.tableIdentifier, u.alias) + } catch { + case _: NoSuchTableException => + u.failAnalysis(s"no such table ${u.tableName}") } } - } - /** - * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. - */ - object ResolveRelations extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ InsertIntoTable(UnresolvedRelation(tableIdentifier, alias), _, _, _) => + case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _) => i.copy( - table = EliminateAnalysisOperators(catalog.lookupRelation(tableIdentifier, alias))) - case UnresolvedRelation(tableIdentifier, alias) => - catalog.lookupRelation(tableIdentifier, alias) + table = EliminateSubQueries(getTable(u))) + case u: UnresolvedRelation => + getTable(u) } } @@ -256,6 +207,12 @@ class Analyzer(catalog: Catalog, case o => o :: Nil } Alias(child = f.copy(children = expandedArgs), name)() :: Nil + case Alias(c @ CreateArray(args), name) if containsStar(args) => + val expandedArgs = args.flatMap { + case s: Star => s.expand(child.output, resolver) + case o => o :: Nil + } + Alias(c.copy(children = expandedArgs), name)() :: Nil case o => o :: Nil }, child) @@ -276,39 +233,114 @@ class Analyzer(catalog: Catalog, } ) + // Special handling for cases when self-join introduce duplicate expression ids. + case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") + + val (oldRelation, newRelation) = right.collect { + // Handle base relations that might appear more than once. + case oldVersion: MultiInstanceRelation + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = oldVersion.newInstance() + (oldVersion, newVersion) + + // Handle projects that create conflicting aliases. + case oldVersion @ Project(projectList, _) + if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) + + case oldVersion @ Aggregate(_, aggregateExpressions, _) + if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass. + sys.error( + s""" + |Failure when resolving conflicting references in Join: + |$plan + | + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + """.stripMargin) + } + + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val newRight = right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } + } + j.copy(right = newRight) + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q transformExpressions { + q transformExpressionsUp { case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) && q.isInstanceOf[GroupingAnalytics] => // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics q.asInstanceOf[GroupingAnalytics].gid case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = q.resolveChildren(name, resolver).getOrElse(u) + val result = + withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result - - // Resolve field names using the resolver. - case f @ GetField(child, fieldName) if !f.resolved && child.resolved => - child.dataType match { - case StructType(fields) => - val resolvedFieldName = fields.map(_.name).find(resolver(_, fieldName)) - resolvedFieldName.map(n => f.copy(fieldName = n)).getOrElse(f) - case _ => f - } + case UnresolvedGetField(child, fieldName) if child.resolved => + resolveGetField(child, fieldName) } } + def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { + expressions.map { + case a: Alias => Alias(a.child, a.name)() + case other => other + } + } + + def findAliases(projectList: Seq[NamedExpression]): AttributeSet = { + AttributeSet(projectList.collect { case a: Alias => a.toAttribute }) + } + /** * Returns true if `exprs` contains a [[Star]]. */ protected def containsStar(exprs: Seq[Expression]): Boolean = exprs.exists(_.collect { case _: Star => true }.nonEmpty) + + /** + * Returns the resolved `GetField`, and report error if no desired field or over one + * desired fields are found. + */ + protected def resolveGetField(expr: Expression, fieldName: String): Expression = { + def findField(fields: Array[StructField]): Int = { + val checkField = (f: StructField) => resolver(f.name, fieldName) + val ordinal = fields.indexWhere(checkField) + if (ordinal == -1) { + throw new AnalysisException( + s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") + } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { + throw new AnalysisException( + s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") + } else { + ordinal + } + } + expr.dataType match { + case StructType(fields) => + val ordinal = findField(fields) + StructGetField(expr, fields(ordinal), ordinal) + case ArrayType(StructType(fields), containsNull) => + val ordinal = findField(fields) + ArrayGetField(expr, fields(ordinal), ordinal, containsNull) + case otherType => + throw new AnalysisException(s"GetField is not valid on fields of type $otherType") + } + } } /** - * In many dialects of SQL is it valid to sort by attributes that are not present in the SELECT + * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original * projection, so that they will be available during sorting. Another projection is added to * remove these attributes after sorting. @@ -317,18 +349,16 @@ class Analyzer(catalog: Catalog, def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ Sort(ordering, global, p @ Project(projectList, child)) if !s.resolved && p.resolved => - val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) - val resolved = unresolved.flatMap(child.resolve(_, resolver)) - val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a }) + val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child) - val missingInProject = requiredAttributes -- p.output - if (missingInProject.nonEmpty) { + // If this rule was not a no-op, return the transformed plan, otherwise return the original. + if (missing.nonEmpty) { // Add missing attributes and then project them away after the sort. - Project(projectList.map(_.toAttribute), - Sort(ordering, global, - Project(projectList ++ missingInProject, child))) + Project(p.output, + Sort(resolvedOrdering, global, + Project(projectList ++ missing, child))) } else { - logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}") + logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) @@ -340,18 +370,54 @@ class Analyzer(catalog: Catalog, grouping.collect { case ne: NamedExpression => ne.toAttribute } ) - logDebug(s"Grouping expressions: $groupingRelation") - val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver)) - val missingInAggs = resolved.filterNot(a.outputSet.contains) - logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") - if (missingInAggs.nonEmpty) { + val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation) + + if (missing.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. Project(a.output, - Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child))) + Sort(resolvedOrdering, global, + Aggregate(grouping, aggs ++ missing, child))) } else { s // Nothing we can do here. Return original plan. } } + + /** + * Given a child and a grandchild that are present beneath a sort operator, returns + * a resolved sort ordering and a list of attributes that are missing from the child + * but are present in the grandchild. + */ + def resolveAndFindMissing( + ordering: Seq[SortOrder], + child: LogicalPlan, + grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + // Find any attributes that remain unresolved in the sort. + val unresolved: Seq[String] = + ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + + // Create a map from name, to resolved attributes, when the desired name can be found + // prior to the projection. + val resolved: Map[String, NamedExpression] = + unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap + + // Construct a set that contains all of the attributes that we need to evaluate the + // ordering. + val requiredAttributes = AttributeSet(resolved.values) + + // Figure out which ones are missing from the projection, so that we can add them and + // remove them after the sort. + val missingInProject = requiredAttributes -- child.output + + // Now that we have all the attributes we need, reconstruct a resolved ordering. + // It is important to do it here, instead of waiting for the standard resolved as adding + // attributes to the project below can actually introduce ambiquity that was not present + // before. + val resolvedOrdering = ordering.map(_ transform { + case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u) + }).asInstanceOf[Seq[SortOrder]] + + (resolvedOrdering, missingInProject.toSeq) + } } /** @@ -427,7 +493,7 @@ class Analyzer(catalog: Catalog, * only required to provide scoping information for attributes and can be removed once analysis is * complete. */ -object EliminateAnalysisOperators extends Rule[LogicalPlan] { +object EliminateSubQueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Subquery(_, child) => child } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index df8d03b86c533..5eb7dff0cede8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -21,6 +21,12 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} +/** + * Thrown by a catalog when a table cannot be found. The analzyer will rethrow the exception + * as an AnalysisException with the correct position information. + */ +class NoSuchTableException extends Exception + /** * An interface for looking up relations by name. Used by an [[Analyzer]]. */ @@ -34,6 +40,14 @@ trait Catalog { tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan + /** + * Returns tuples of (tableName, isTemporary) for all tables in the given database. + * isTemporary is a Boolean value indicates if a table is a temporary or not. + */ + def getTables(databaseName: Option[String]): Seq[(String, Boolean)] + + def refreshTable(databaseName: String, tableName: String): Unit + def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit def unregisterTable(tableIdentifier: Seq[String]): Unit @@ -72,12 +86,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tables += ((getDbTableName(tableIdent), plan)) } - override def unregisterTable(tableIdentifier: Seq[String]) = { + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) tables -= getDbTableName(tableIdent) } - override def unregisterAllTables() = { + override def unregisterAllTables(): Unit = { tables.clear() } @@ -101,6 +115,16 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { // properly qualified with this alias. alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) } + + override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { + tables.map { + case (name, _) => (name, true) + }.toSeq + } + + override def refreshTable(databaseName: String, tableName: String): Unit = { + throw new UnsupportedOperationException + } } /** @@ -123,8 +147,8 @@ trait OverrideCatalog extends Catalog { } abstract override def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None): LogicalPlan = { + tableIdentifier: Seq[String], + alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val overriddenTable = overrides.get(getDBTable(tableIdent)) val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r)) @@ -137,6 +161,27 @@ trait OverrideCatalog extends Catalog { withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias)) } + abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { + val dbName = if (!caseSensitive) { + if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None + } else { + databaseName + } + + val temporaryTables = overrides.filter { + // If a temporary table does not have an associated database, we should return its name. + case ((None, _), _) => true + // If a temporary table does have an associated database, we should return it if the database + // matches the given database name. + case ((db: Some[String], _), _) if db == dbName => true + case _ => false + }.map { + case ((_, tableName), _) => (tableName, true) + }.toSeq + + temporaryTables ++ super.getTables(databaseName) + } + override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { @@ -160,25 +205,33 @@ trait OverrideCatalog extends Catalog { */ object EmptyCatalog extends Catalog { - val caseSensitive: Boolean = true + override val caseSensitive: Boolean = true - def tableExists(tableIdentifier: Seq[String]): Boolean = { + override def tableExists(tableIdentifier: Seq[String]): Boolean = { throw new UnsupportedOperationException } - def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None) = { + override def lookupRelation( + tableIdentifier: Seq[String], + alias: Option[String] = None): LogicalPlan = { + throw new UnsupportedOperationException + } + + override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { throw new UnsupportedOperationException } - def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } - def unregisterTable(tableIdentifier: Seq[String]): Unit = { + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { throw new UnsupportedOperationException } override def unregisterAllTables(): Unit = {} + + override def refreshTable(databaseName: String, tableName: String): Unit = { + throw new UnsupportedOperationException + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala new file mode 100644 index 0000000000000..fa02111385c06 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ + +/** + * Throws user facing errors when passed invalid queries that fail to analyze. + */ +trait CheckAnalysis { + self: Analyzer => + + /** + * Override to provide additional checks for correct analysis. + * These rules will be evaluated after our built-in check rules. + */ + val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil + + protected def failAnalysis(msg: String): Nothing = { + throw new AnalysisException(msg) + } + + def checkAnalysis(plan: LogicalPlan): Unit = { + // We transform up and order the rules so as to catch the first possible failure instead + // of the result of cascading resolution failures. + plan.foreachUp { + case operator: LogicalPlan => + operator transformExpressionsUp { + case a: Attribute if !a.resolved => + if (operator.childrenResolved) { + // Throw errors for specific problems with get field. + operator.resolveChildren(a.name, resolver, throwErrors = true) + } + + val from = operator.inputSet.map(_.name).mkString(", ") + a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + + case c: Cast if !c.resolved => + failAnalysis( + s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + + case b: BinaryExpression if !b.resolved => + failAnalysis( + s"invalid expression ${b.prettyString} " + + s"between ${b.left.simpleString} and ${b.right.simpleString}") + } + + operator match { + case f: Filter if f.condition.dataType != BooleanType => + failAnalysis( + s"filter expression '${f.condition.prettyString}' " + + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + + case Aggregate(groupingExprs, aggregateExprs, child) => + def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK + case e: Attribute if !groupingExprs.contains(e) => + failAnalysis( + s"expression '${e.prettyString}' is neither present in the group by, " + + s"nor is it an aggregate function. " + + "Add to group by or wrap in first() if you don't care which value you get.") + case e if groupingExprs.contains(e) => // OK + case e if e.references.isEmpty => // OK + case e => e.children.foreach(checkValidAggregateExpression) + } + + val cleaned = aggregateExprs.map(_.transform { + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + case Alias(g, _) => g + }) + + cleaned.foreach(checkValidAggregateExpression) + + case _ => // Fallbacks to the following checks + } + + operator match { + case o if o.children.nonEmpty && o.missingInput.nonEmpty => + val missingAttributes = o.missingInput.mkString(",") + val input = o.inputSet.mkString(",") + + failAnalysis( + s"resolved attribute(s) $missingAttributes missing from $input " + + s"in operator ${operator.simpleString}") + + case o if !o.resolved => + failAnalysis( + s"unresolved operator ${operator.simpleString}") + + case _ => // Analysis successful! + } + } + extendedCheckRules.foreach(_(plan)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 760c49fbca4a5..c43ea55899695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -27,25 +27,27 @@ trait FunctionRegistry { def registerFunction(name: String, builder: FunctionBuilder): Unit def lookupFunction(name: String, children: Seq[Expression]): Expression + + def caseSensitive: Boolean } trait OverrideFunctionRegistry extends FunctionRegistry { - val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() + val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) - def registerFunction(name: String, builder: FunctionBuilder) = { + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name,children)) + functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children)) } } -class SimpleFunctionRegistry extends FunctionRegistry { - val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() +class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry { + val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) - def registerFunction(name: String, builder: FunctionBuilder) = { + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } @@ -59,9 +61,37 @@ class SimpleFunctionRegistry extends FunctionRegistry { * functions are already filled in and the analyser needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { - def registerFunction(name: String, builder: FunctionBuilder) = ??? + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { + throw new UnsupportedOperationException + } - def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } + + override def caseSensitive: Boolean = throw new UnsupportedOperationException } + +/** + * Build a map with String type of key, and it also supports either key case + * sensitive or insensitive. + * TODO move this into util folder? + */ +object StringKeyHashMap { + def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { + case false => new StringKeyHashMap[T](_.toLowerCase) + case true => new StringKeyHashMap[T](identity) + } +} + +class StringKeyHashMap[T](normalizer: (String) => String) { + private val base = new collection.mutable.HashMap[String, T]() + + def apply(key: String): T = base(normalizer(key)) + + def get(key: String): Option[T] = base.get(normalizer(key)) + def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) + def remove(key: String): Option[T] = base.remove(normalizer(key)) + def iterator: Iterator[(String, T)] = base.toIterator +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 34ef7d28cc7f2..3c7b46e0702a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -78,6 +78,7 @@ trait HiveTypeCoercion { FunctionArgumentConversion :: CaseWhenCoercion :: Division :: + PropagateTypes :: Nil /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 4c5fb3f45bf49..35b74024a4cab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -26,28 +26,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan * produced by distinct operators in a query tree as this breaks the guarantee that expression * ids, which are used to differentiate attributes, are unique. * - * Before analysis, all operators that include this trait will be asked to produce a new version + * During analysis, operators that include this trait may be asked to produce a new version * of itself with globally unique expression ids. */ trait MultiInstanceRelation { - def newInstance(): this.type -} - -/** - * If any MultiInstanceRelation appears more than once in the query plan then the plan is updated so - * that each instance has unique expression ids for the attributes produced. - */ -object NewRelationInstances extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - val localRelations = plan collect { case l: MultiInstanceRelation => l} - val multiAppearance = localRelations - .groupBy(identity[MultiInstanceRelation]) - .filter { case (_, ls) => ls.size > 1 } - .map(_._1) - .toSet - - plan transform { - case l: MultiInstanceRelation if multiAppearance.contains(l) => l.newInstance() - } - } + def newInstance(): LogicalPlan } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index 3f672a3e0fd91..c61c395cb4bb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.trees.TreeNode + /** * Provides a logical query plan [[Analyzer]] and supporting classes for performing analysis. * Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s @@ -25,11 +28,26 @@ package org.apache.spark.sql.catalyst package object analysis { /** - * Responsible for resolving which identifiers refer to the same entity. For example, by using - * case insensitive equality. + * Resolver should return true if the first string refers to the same entity as the second string. + * For example, by using case insensitive equality. */ type Resolver = (String, String) => Boolean val caseInsensitiveResolution = (a: String, b: String) => a.equalsIgnoreCase(b) val caseSensitiveResolution = (a: String, b: String) => a == b + + implicit class AnalysisErrorAt(t: TreeNode[_]) { + /** Fails the analysis at the point where a specific tree node was parsed. */ + def failAnalysis(msg: String): Nothing = { + throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) + } + } + + /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ + def withPosition[A](t: TreeNode[_])(f: => A) = { + try f catch { + case a: AnalysisException => + throw a.withPosition(t.origin.line, t.origin.startPosition) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 66060289189ef..300e9ba187bc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.types.DataType /** * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully @@ -36,7 +37,12 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str case class UnresolvedRelation( tableIdentifier: Seq[String], alias: Option[String] = None) extends LeafNode { - override def output = Nil + + /** Returns a `.` separated name for this relation. */ + def tableName: String = tableIdentifier.mkString(".") + + override def output: Seq[Attribute] = Nil + override lazy val resolved = false } @@ -44,16 +50,16 @@ case class UnresolvedRelation( * Holds the name of an attribute that has yet to be resolved. */ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { - override def exprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance() = this - override def withNullability(newNullability: Boolean) = this - override def withQualifiers(newQualifiers: Seq[String]) = this - override def withName(newName: String) = UnresolvedAttribute(name) + override def newInstance(): UnresolvedAttribute = this + override def withNullability(newNullability: Boolean): UnresolvedAttribute = this + override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this + override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name) // Unresolved attributes are transient at compile time and don't get evaluated during execution. override def eval(input: Row = null): EvaluatedType = @@ -63,16 +69,16 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo } case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { - override def dataType = throw new UnresolvedException(this, "dataType") - override def foldable = throw new UnresolvedException(this, "foldable") - override def nullable = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = s"'$name(${children.mkString(",")})" + override def toString: String = s"'$name(${children.mkString(",")})" } /** @@ -82,17 +88,17 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E trait Star extends Attribute with trees.LeafNode[Expression] { self: Product => - override def name = throw new UnresolvedException(this, "name") - override def exprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def name: String = throw new UnresolvedException(this, "name") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance() = this - override def withNullability(newNullability: Boolean) = this - override def withQualifiers(newQualifiers: Seq[String]) = this - override def withName(newName: String) = this + override def newInstance(): Star = this + override def withNullability(newNullability: Boolean): Star = this + override def withQualifiers(newQualifiers: Seq[String]): Star = this + override def withName(newName: String): Star = this // Star gets expanded at runtime so we never evaluate a Star. override def eval(input: Row = null): EvaluatedType = @@ -125,9 +131,47 @@ case class UnresolvedStar(table: Option[String]) extends Star { } } - override def toString = table.map(_ + ".").getOrElse("") + "*" + override def toString: String = table.map(_ + ".").getOrElse("") + "*" } +/** + * Used to assign new names to Generator's output, such as hive udtf. + * For example the SQL expression "stack(2, key, value, key, value) as (a, b)" could be represented + * as follows: + * MultiAlias(stack_function, Seq(a, b)) + + * @param child the computation being performed + * @param names the names to be associated with each output of computing [[child]]. + */ +case class MultiAlias(child: Expression, names: Seq[String]) + extends Attribute with trees.UnaryNode[Expression] { + + override def name: String = throw new UnresolvedException(this, "name") + + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") + + override lazy val resolved = false + + override def newInstance(): MultiAlias = this + + override def withNullability(newNullability: Boolean): MultiAlias = this + + override def withQualifiers(newQualifiers: Seq[String]): MultiAlias = this + + override def withName(newName: String): MultiAlias = this + + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + + override def toString: String = s"$child AS $names" + +} /** * Represents all the resolved input attributes to a given relational operator. This is used @@ -137,5 +181,17 @@ case class UnresolvedStar(table: Option[String]) extends Star { */ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions - override def toString = expressions.mkString("ResolvedStar(", ", ", ")") + override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") +} + +case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false + + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + + override def toString: String = s"$child.$fieldName" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala old mode 100755 new mode 100644 index 417659eed5957..145f062dd6817 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -61,60 +61,60 @@ package object dsl { trait ImplicitOperators { def expr: Expression - def unary_- = UnaryMinus(expr) - def unary_! = Not(expr) - def unary_~ = BitwiseNot(expr) - - def + (other: Expression) = Add(expr, other) - def - (other: Expression) = Subtract(expr, other) - def * (other: Expression) = Multiply(expr, other) - def / (other: Expression) = Divide(expr, other) - def % (other: Expression) = Remainder(expr, other) - def & (other: Expression) = BitwiseAnd(expr, other) - def | (other: Expression) = BitwiseOr(expr, other) - def ^ (other: Expression) = BitwiseXor(expr, other) - - def && (other: Expression) = And(expr, other) - def || (other: Expression) = Or(expr, other) - - def < (other: Expression) = LessThan(expr, other) - def <= (other: Expression) = LessThanOrEqual(expr, other) - def > (other: Expression) = GreaterThan(expr, other) - def >= (other: Expression) = GreaterThanOrEqual(expr, other) - def === (other: Expression) = EqualTo(expr, other) - def <=> (other: Expression) = EqualNullSafe(expr, other) - def !== (other: Expression) = Not(EqualTo(expr, other)) - - def in(list: Expression*) = In(expr, list) - - def like(other: Expression) = Like(expr, other) - def rlike(other: Expression) = RLike(expr, other) - def contains(other: Expression) = Contains(expr, other) - def startsWith(other: Expression) = StartsWith(expr, other) - def endsWith(other: Expression) = EndsWith(expr, other) - def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + def unary_- : Expression= UnaryMinus(expr) + def unary_! : Predicate = Not(expr) + def unary_~ : Expression = BitwiseNot(expr) + + def + (other: Expression): Expression = Add(expr, other) + def - (other: Expression): Expression = Subtract(expr, other) + def * (other: Expression): Expression = Multiply(expr, other) + def / (other: Expression): Expression = Divide(expr, other) + def % (other: Expression): Expression = Remainder(expr, other) + def & (other: Expression): Expression = BitwiseAnd(expr, other) + def | (other: Expression): Expression = BitwiseOr(expr, other) + def ^ (other: Expression): Expression = BitwiseXor(expr, other) + + def && (other: Expression): Predicate = And(expr, other) + def || (other: Expression): Predicate = Or(expr, other) + + def < (other: Expression): Predicate = LessThan(expr, other) + def <= (other: Expression): Predicate = LessThanOrEqual(expr, other) + def > (other: Expression): Predicate = GreaterThan(expr, other) + def >= (other: Expression): Predicate = GreaterThanOrEqual(expr, other) + def === (other: Expression): Predicate = EqualTo(expr, other) + def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) + def !== (other: Expression): Predicate = Not(EqualTo(expr, other)) + + def in(list: Expression*): Expression = In(expr, list) + + def like(other: Expression): Expression = Like(expr, other) + def rlike(other: Expression): Expression = RLike(expr, other) + def contains(other: Expression): Expression = Contains(expr, other) + def startsWith(other: Expression): Expression = StartsWith(expr, other) + def endsWith(other: Expression): Expression = EndsWith(expr, other) + def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression = Substring(expr, pos, len) - def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression = Substring(expr, pos, len) - def isNull = IsNull(expr) - def isNotNull = IsNotNull(expr) + def isNull: Predicate = IsNull(expr) + def isNotNull: Predicate = IsNotNull(expr) - def getItem(ordinal: Expression) = GetItem(expr, ordinal) - def getField(fieldName: String) = GetField(expr, fieldName) + def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal) + def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName) - def cast(to: DataType) = Cast(expr, to) + def cast(to: DataType): Expression = Cast(expr, to) - def asc = SortOrder(expr, Ascending) - def desc = SortOrder(expr, Descending) + def asc: SortOrder = SortOrder(expr, Ascending) + def desc: SortOrder = SortOrder(expr, Descending) - def as(alias: String) = Alias(expr, alias)() - def as(alias: Symbol) = Alias(expr, alias.name)() + def as(alias: String): NamedExpression = Alias(expr, alias)() + def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } trait ExpressionConversions { implicit class DslExpression(e: Expression) extends ImplicitOperators { - def expr = e + def expr: Expression = e } implicit def booleanToLiteral(b: Boolean): Literal = Literal(b) @@ -144,94 +144,100 @@ package object dsl { } } - def sum(e: Expression) = Sum(e) - def sumDistinct(e: Expression) = SumDistinct(e) - def count(e: Expression) = Count(e) - def countDistinct(e: Expression*) = CountDistinct(e) - def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd) - def avg(e: Expression) = Average(e) - def first(e: Expression) = First(e) - def last(e: Expression) = Last(e) - def min(e: Expression) = Min(e) - def max(e: Expression) = Max(e) - def upper(e: Expression) = Upper(e) - def lower(e: Expression) = Lower(e) - def sqrt(e: Expression) = Sqrt(e) - def abs(e: Expression) = Abs(e) - - implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } + def sum(e: Expression): Expression = Sum(e) + def sumDistinct(e: Expression): Expression = SumDistinct(e) + def count(e: Expression): Expression = Count(e) + def countDistinct(e: Expression*): Expression = CountDistinct(e) + def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = + ApproxCountDistinct(e, rsd) + def avg(e: Expression): Expression = Average(e) + def first(e: Expression): Expression = First(e) + def last(e: Expression): Expression = Last(e) + def min(e: Expression): Expression = Min(e) + def max(e: Expression): Expression = Max(e) + def upper(e: Expression): Expression = Upper(e) + def lower(e: Expression): Expression = Lower(e) + def sqrt(e: Expression): Expression = Sqrt(e) + def abs(e: Expression): Expression = Abs(e) + + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { override def expr: Expression = Literal(s) - def attr = analysis.UnresolvedAttribute(s) + def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) } abstract class ImplicitAttribute extends ImplicitOperators { def s: String - def expr = attr - def attr = analysis.UnresolvedAttribute(s) + def expr: UnresolvedAttribute = attr + def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) /** Creates a new AttributeReference of type boolean */ - def boolean = AttributeReference(s, BooleanType, nullable = true)() + def boolean: AttributeReference = AttributeReference(s, BooleanType, nullable = true)() /** Creates a new AttributeReference of type byte */ - def byte = AttributeReference(s, ByteType, nullable = true)() + def byte: AttributeReference = AttributeReference(s, ByteType, nullable = true)() /** Creates a new AttributeReference of type short */ - def short = AttributeReference(s, ShortType, nullable = true)() + def short: AttributeReference = AttributeReference(s, ShortType, nullable = true)() /** Creates a new AttributeReference of type int */ - def int = AttributeReference(s, IntegerType, nullable = true)() + def int: AttributeReference = AttributeReference(s, IntegerType, nullable = true)() /** Creates a new AttributeReference of type long */ - def long = AttributeReference(s, LongType, nullable = true)() + def long: AttributeReference = AttributeReference(s, LongType, nullable = true)() /** Creates a new AttributeReference of type float */ - def float = AttributeReference(s, FloatType, nullable = true)() + def float: AttributeReference = AttributeReference(s, FloatType, nullable = true)() /** Creates a new AttributeReference of type double */ - def double = AttributeReference(s, DoubleType, nullable = true)() + def double: AttributeReference = AttributeReference(s, DoubleType, nullable = true)() /** Creates a new AttributeReference of type string */ - def string = AttributeReference(s, StringType, nullable = true)() + def string: AttributeReference = AttributeReference(s, StringType, nullable = true)() /** Creates a new AttributeReference of type date */ - def date = AttributeReference(s, DateType, nullable = true)() + def date: AttributeReference = AttributeReference(s, DateType, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)() + def decimal: AttributeReference = + AttributeReference(s, DecimalType.Unlimited, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal(precision: Int, scale: Int) = + def decimal(precision: Int, scale: Int): AttributeReference = AttributeReference(s, DecimalType(precision, scale), nullable = true)() /** Creates a new AttributeReference of type timestamp */ - def timestamp = AttributeReference(s, TimestampType, nullable = true)() + def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)() /** Creates a new AttributeReference of type binary */ - def binary = AttributeReference(s, BinaryType, nullable = true)() + def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)() /** Creates a new AttributeReference of type array */ - def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)() + def array(dataType: DataType): AttributeReference = + AttributeReference(s, ArrayType(dataType), nullable = true)() /** Creates a new AttributeReference of type map */ def map(keyType: DataType, valueType: DataType): AttributeReference = map(MapType(keyType, valueType)) - def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)() + + def map(mapType: MapType): AttributeReference = + AttributeReference(s, mapType, nullable = true)() /** Creates a new AttributeReference of type struct */ def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) - def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)() + def struct(structType: StructType): AttributeReference = + AttributeReference(s, structType, nullable = true)() } implicit class DslAttribute(a: AttributeReference) { - def notNull = a.withNullability(false) - def nullable = a.withNullability(true) + def notNull: AttributeReference = a.withNullability(false) + def nullable: AttributeReference = a.withNullability(true) // Protobuf terminology - def required = a.withNullability(false) + def required: AttributeReference = a.withNullability(false) - def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable) + def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } @@ -241,23 +247,23 @@ package object dsl { abstract class LogicalPlanFunctions { def logicalPlan: LogicalPlan - def select(exprs: NamedExpression*) = Project(exprs, logicalPlan) + def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) - def where(condition: Expression) = Filter(condition, logicalPlan) + def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) - def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, - condition: Option[Expression] = None) = + condition: Option[Expression] = None): LogicalPlan = Join(logicalPlan, otherPlan, joinType, condition) - def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, true, logicalPlan) + def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) - def sortBy(sortExprs: SortOrder*) = Sort(sortExprs, false, logicalPlan) + def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = { + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { val aliasedExprs = aggregateExprs.map { case ne: NamedExpression => ne case e => Alias(e, e.toString)() @@ -265,41 +271,43 @@ package object dsl { Aggregate(groupingExprs, aliasedExprs, logicalPlan) } - def subquery(alias: Symbol) = Subquery(alias.name, logicalPlan) + def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) - def unionAll(otherPlan: LogicalPlan) = Union(logicalPlan, otherPlan) + def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) = + def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) def sample( fraction: Double, withReplacement: Boolean = true, - seed: Int = (math.random * 1000).toInt) = + seed: Int = (math.random * 1000).toInt): LogicalPlan = Sample(fraction, withReplacement, seed, logicalPlan) def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None) = + alias: Option[String] = None): LogicalPlan = Generate(generator, join, outer, None, logicalPlan) - def insertInto(tableName: String, overwrite: Boolean = false) = + def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite) - def analyze = analysis.SimpleAnalyzer(logicalPlan) + def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer(logicalPlan)) } object plans { // scalastyle:ignore implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions { - def writeToFile(path: String) = WriteToFile(path, logicalPlan) + def writeToFile(path: String): LogicalPlan = WriteToFile(path, logicalPlan) } } case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { - def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + def call(args: Expression*): ScalaUdf = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + } } // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 82e760b6c6916..96a11e352ec50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -23,7 +23,9 @@ package org.apache.spark.sql.catalyst.expressions * of the name, or the expected nullability). */ object AttributeMap { - def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) + def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = { + new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) + } } class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 171845ad14e3e..5345696570b41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -17,26 +17,29 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.Star protected class AttributeEquals(val a: Attribute) { - override def hashCode() = a.exprId.hashCode() - override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match { + override def hashCode(): Int = a match { + case ar: AttributeReference => ar.exprId.hashCode() + case a => a.hashCode() + } + + override def equals(other: Any): Boolean = (a, other.asInstanceOf[AttributeEquals].a) match { case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId case (a1, a2) => a1 == a2 } } object AttributeSet { - def apply(a: Attribute) = - new AttributeSet(Set(new AttributeEquals(a))) + def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ - def apply(baseSet: Seq[Expression]) = + def apply(baseSet: Iterable[Expression]): AttributeSet = { new AttributeSet( baseSet .flatMap(_.references) .map(new AttributeEquals(_)).toSet) + } } /** @@ -54,8 +57,9 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] with Serializable { /** Returns true if the members of this AttributeSet and other are the same. */ - override def equals(other: Any) = other match { - case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) + override def equals(other: Any): Boolean = other match { + case otherSet: AttributeSet => + otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains) case _ => false } @@ -78,32 +82,34 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in * `other`. */ - def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet) + def subsetOf(other: AttributeSet): Boolean = baseSet.subsetOf(other.baseSet) /** * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found * in `other`. */ - def --(other: Traversable[NamedExpression]) = + def --(other: Traversable[NamedExpression]): AttributeSet = new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) /** * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found * in `other`. */ - def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet) + def ++(other: AttributeSet): AttributeSet = new AttributeSet(baseSet ++ other.baseSet) /** * Returns a new [[AttributeSet]] contain only the [[Attribute Attributes]] where `f` evaluates to * true. */ - override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a))) + override def filter(f: Attribute => Boolean): AttributeSet = + new AttributeSet(baseSet.filter(ae => f(ae.a))) /** * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in * `this` and `other`. */ - def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet)) + def intersect(other: AttributeSet): AttributeSet = + new AttributeSet(baseSet.intersect(other.baseSet)) override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f) @@ -111,7 +117,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) // sorts of things in its closure. override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq - override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}" + override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}" override def isEmpty: Boolean = baseSet.isEmpty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 76a9f08dea85f..2225621dbaabd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -32,7 +32,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) type EvaluatedType = Any - override def toString = s"input[$ordinal]" + override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index ece5ee73618cb..9bde74ac22669 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -29,9 +29,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override lazy val resolved = childrenResolved && resolve(child.dataType, dataType) - override def foldable = child.foldable + override def foldable: Boolean = child.foldable - override def nullable = forceNullable(child.dataType, dataType) || child.nullable + override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true @@ -103,7 +103,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } - override def toString = s"CAST($child, $dataType)" + override def toString: String = s"CAST($child, $dataType)" type EvaluatedType = Any @@ -113,7 +113,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) - case DateType => buildCast[Date](_, dateToString) + case DateType => buildCast[Int](_, d => DateUtils.toString(d)) case TimestampType => buildCast[Timestamp](_, timestampToString) case _ => buildCast[Any](_, _.toString) } @@ -131,7 +131,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) case DateType => // Hive would return null when cast from date to boolean - buildCast[Date](_, d => null) + buildCast[Int](_, d => null) case LongType => buildCast[Long](_, _ != 0) case IntegerType => @@ -171,7 +171,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, b => new Timestamp(b)) case DateType => - buildCast[Date](_, d => new Timestamp(d.getTime)) + buildCast[Int](_, d => new Timestamp(DateUtils.toJavaDate(d).getTime)) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -224,37 +224,24 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } - // Converts Timestamp to string according to Hive TimestampWritable convention - private[this] def timestampToDateString(ts: Timestamp): String = { - Cast.threadLocalDateFormat.get.format(ts) - } - // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, s => - try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null }) + try DateUtils.fromJavaDate(Date.valueOf(s)) + catch { case _: java.lang.IllegalArgumentException => null } + ) case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. - buildCast[Timestamp](_, t => new Date(Math.floor(t.getTime / 1000.0).toLong * 1000)) + buildCast[Timestamp](_, t => DateUtils.millisToDays(t.getTime)) // Hive throws this exception as a Semantic Exception - // It is never possible to compare result when hive return with exception, so we can return null + // It is never possible to compare result when hive return with exception, + // so we can return null // NULL is more reasonable here, since the query itself obeys the grammar. case _ => _ => null } - // Date cannot be cast to long, according to hive - private[this] def dateToLong(d: Date) = null - - // Date cannot be cast to double, according to hive - private[this] def dateToDouble(d: Date) = null - - // Converts Date to string according to Hive DateWritable convention - private[this] def dateToString(d: Date): String = { - Cast.threadLocalDateFormat.get.format(d) - } - // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => @@ -264,7 +251,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t)) case x: NumericType => @@ -280,7 +267,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toInt) case x: NumericType => @@ -296,7 +283,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toShort) case x: NumericType => @@ -312,7 +299,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toByte) case x: NumericType => @@ -342,7 +329,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target)) case DateType => - buildCast[Date](_, d => null) // date can't cast to decimal in Hive + buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) @@ -367,7 +354,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1d else 0d) case DateType => - buildCast[Date](_, d => dateToDouble(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t)) case x: NumericType => @@ -383,7 +370,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1f else 0f) case DateType => - buildCast[Date](_, d => dateToDouble(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) case x: NumericType => @@ -442,16 +429,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w object Cast { // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { - override def initialValue() = { - new SimpleDateFormat("yyyy-MM-dd") + private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") } } // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { - override def initialValue() = { - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index cf14992ef835c..4e3bbc06a5b4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode @@ -64,206 +65,17 @@ abstract class Expression extends TreeNode[Expression] { * Returns true if all the children of this expression have been resolved to a specific schema * and false if any still contains any unresolved placeholders. */ - def childrenResolved = !children.exists(!_.resolved) + def childrenResolved: Boolean = !children.exists(!_.resolved) /** - * A set of helper functions that return the correct descendant of `scala.math.Numeric[T]` type - * and do any casting necessary of child evaluation. + * Returns a string representation of this expression that does not have developer centric + * debugging information like the expression id. */ - @inline - def n1(e: Expression, i: Row, f: ((Numeric[Any], Any) => Any)): Any = { - val evalE = e.eval(i) - if (evalE == null) { - null - } else { - e.dataType match { - case n: NumericType => - val castedFunction = f.asInstanceOf[(Numeric[n.JvmType], n.JvmType) => n.JvmType] - castedFunction(n.numeric, evalE.asInstanceOf[n.JvmType]) - case other => sys.error(s"Type $other does not support numeric operations") - } - } - } - - /** - * Evaluation helper function for 2 Numeric children expressions. Those expressions are supposed - * to be in the same data type, and also the return type. - * Either one of the expressions result is null, the evaluation result should be null. - */ - @inline - protected final def n2( - i: Row, - e1: Expression, - e2: Expression, - f: ((Numeric[Any], Any, Any) => Any)): Any = { - - if (e1.dataType != e2.dataType) { - throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") - } - - val evalE1 = e1.eval(i) - if(evalE1 == null) { - null - } else { - val evalE2 = e2.eval(i) - if (evalE2 == null) { - null - } else { - e1.dataType match { - case n: NumericType => - f.asInstanceOf[(Numeric[n.JvmType], n.JvmType, n.JvmType) => n.JvmType]( - n.numeric, evalE1.asInstanceOf[n.JvmType], evalE2.asInstanceOf[n.JvmType]) - case other => sys.error(s"Type $other does not support numeric operations") - } - } - } - } - - /** - * Evaluation helper function for 2 Fractional children expressions. Those expressions are - * supposed to be in the same data type, and also the return type. - * Either one of the expressions result is null, the evaluation result should be null. - */ - @inline - protected final def f2( - i: Row, - e1: Expression, - e2: Expression, - f: ((Fractional[Any], Any, Any) => Any)): Any = { - if (e1.dataType != e2.dataType) { - throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") - } - - val evalE1 = e1.eval(i: Row) - if(evalE1 == null) { - null - } else { - val evalE2 = e2.eval(i: Row) - if (evalE2 == null) { - null - } else { - e1.dataType match { - case ft: FractionalType => - f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType, ft.JvmType) => ft.JvmType]( - ft.fractional, evalE1.asInstanceOf[ft.JvmType], evalE2.asInstanceOf[ft.JvmType]) - case other => sys.error(s"Type $other does not support fractional operations") - } - } - } - } - - /** - * Evaluation helper function for 1 Fractional children expression. - * if the expression result is null, the evaluation result should be null. - */ - @inline - protected final def f1(i: Row, e1: Expression, f: ((Fractional[Any], Any) => Any)): Any = { - val evalE1 = e1.eval(i: Row) - if(evalE1 == null) { - null - } else { - e1.dataType match { - case ft: FractionalType => - f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType) => ft.JvmType]( - ft.fractional, evalE1.asInstanceOf[ft.JvmType]) - case other => sys.error(s"Type $other does not support fractional operations") - } - } - } - - /** - * Evaluation helper function for 2 Integral children expressions. Those expressions are - * supposed to be in the same data type, and also the return type. - * Either one of the expressions result is null, the evaluation result should be null. - */ - @inline - protected final def i2( - i: Row, - e1: Expression, - e2: Expression, - f: ((Integral[Any], Any, Any) => Any)): Any = { - if (e1.dataType != e2.dataType) { - throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") - } - - val evalE1 = e1.eval(i) - if(evalE1 == null) { - null - } else { - val evalE2 = e2.eval(i) - if (evalE2 == null) { - null - } else { - e1.dataType match { - case i: IntegralType => - f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType]( - i.integral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) - case i: FractionalType => - f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType]( - i.asIntegral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) - case other => sys.error(s"Type $other does not support numeric operations") - } - } - } - } - - /** - * Evaluation helper function for 1 Integral children expression. - * if the expression result is null, the evaluation result should be null. - */ - @inline - protected final def i1(i: Row, e1: Expression, f: ((Integral[Any], Any) => Any)): Any = { - val evalE1 = e1.eval(i) - if(evalE1 == null) { - null - } else { - e1.dataType match { - case i: IntegralType => - f.asInstanceOf[(Integral[i.JvmType], i.JvmType) => i.JvmType]( - i.integral, evalE1.asInstanceOf[i.JvmType]) - case i: FractionalType => - f.asInstanceOf[(Integral[i.JvmType], i.JvmType) => i.JvmType]( - i.asIntegral, evalE1.asInstanceOf[i.JvmType]) - case other => sys.error(s"Type $other does not support numeric operations") - } - } - } - - /** - * Evaluation helper function for 2 Comparable children expressions. Those expressions are - * supposed to be in the same data type, and the return type should be Integer: - * Negative value: 1st argument less than 2nd argument - * Zero: 1st argument equals 2nd argument - * Positive value: 1st argument greater than 2nd argument - * - * Either one of the expressions result is null, the evaluation result should be null. - */ - @inline - protected final def c2( - i: Row, - e1: Expression, - e2: Expression, - f: ((Ordering[Any], Any, Any) => Any)): Any = { - if (e1.dataType != e2.dataType) { - throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") - } - - val evalE1 = e1.eval(i) - if(evalE1 == null) { - null - } else { - val evalE2 = e2.eval(i) - if (evalE2 == null) { - null - } else { - e1.dataType match { - case i: NativeType => - f.asInstanceOf[(Ordering[i.JvmType], i.JvmType, i.JvmType) => Boolean]( - i.ordering, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) - case other => sys.error(s"Type $other does not support ordered operations") - } - } - } + def prettyString: String = { + transform { + case a: AttributeReference => PrettyAttribute(a.name) + case u: UnresolvedAttribute => PrettyAttribute(u.name) + }.toString } } @@ -272,9 +84,9 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express def symbol: String - override def foldable = left.foldable && right.foldable + override def foldable: Boolean = left.foldable && right.foldable - override def toString = s"($left $symbol $right)" + override def toString: String = s"($left $symbol $right)" } abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { @@ -292,8 +104,8 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio case class GroupExpression(children: Seq[Expression]) extends Expression { self: Product => type EvaluatedType = Seq[Any] - override def eval(input: Row): EvaluatedType = ??? - override def nullable = false - override def foldable = false - override def dataType = ??? + override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def nullable: Boolean = false + override def foldable: Boolean = false + override def dataType: DataType = throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index db5d897ee569f..c2866cd955409 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -40,7 +40,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { new GenericRow(outputArray) } - override def toString = s"Row => [${exprArray.mkString(",")}]" + override def toString: String = s"Row => [${exprArray.mkString(",")}]" } /** @@ -107,12 +107,12 @@ class JoinedRow extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -142,7 +142,7 @@ class JoinedRow extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -153,7 +153,7 @@ class JoinedRow extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -207,12 +207,12 @@ class JoinedRow2 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -242,7 +242,7 @@ class JoinedRow2 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -253,7 +253,7 @@ class JoinedRow2 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -301,12 +301,12 @@ class JoinedRow3 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -336,7 +336,7 @@ class JoinedRow3 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -347,7 +347,7 @@ class JoinedRow3 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -395,12 +395,12 @@ class JoinedRow4 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -430,7 +430,7 @@ class JoinedRow4 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -441,7 +441,7 @@ class JoinedRow4 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -489,12 +489,12 @@ class JoinedRow5 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -524,7 +524,7 @@ class JoinedRow5 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -535,7 +535,7 @@ class JoinedRow5 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala index b2c6d3029031d..f5fea3f015dc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -18,16 +18,19 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Random -import org.apache.spark.sql.types.DoubleType + +import org.apache.spark.sql.types.{DataType, DoubleType} case object Rand extends LeafExpression { - override def dataType = DoubleType - override def nullable = false + override def dataType: DataType = DoubleType + override def nullable: Boolean = false private[this] lazy val rand = new Random - override def eval(input: Row = null) = rand.nextDouble().asInstanceOf[EvaluatedType] + override def eval(input: Row = null): EvaluatedType = { + rand.nextDouble().asInstanceOf[EvaluatedType] + } - override def toString = "RAND()" + override def toString: String = "RAND()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 8a36c6810790d..1fd5ce342b2ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -29,9 +29,9 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi type EvaluatedType = Any - def nullable = true + override def nullable: Boolean = true - override def toString = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"scalaUDF(${children.mkString(",")})" // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d00b2ac09745c..83074eb1e6310 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types.DataType abstract sealed class SortDirection case object Ascending extends SortDirection @@ -31,12 +32,12 @@ case object Descending extends SortDirection case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { - override def dataType = child.dataType - override def nullable = child.nullable + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable // SortOrder itself is never evaluated. override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 7434165f654f8..47b6f358ed1b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -62,126 +62,126 @@ abstract class MutableValue extends Serializable { var isNull: Boolean = true def boxed: Any def update(v: Any) - def copy(): this.type + def copy(): MutableValue } final class MutableInt extends MutableValue { var value: Int = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Int] + value = v.asInstanceOf[Int] } - def copy() = { + override def copy(): MutableInt = { val newCopy = new MutableInt newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableInt] } } final class MutableFloat extends MutableValue { var value: Float = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Float] + value = v.asInstanceOf[Float] } - def copy() = { + override def copy(): MutableFloat = { val newCopy = new MutableFloat newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableFloat] } } final class MutableBoolean extends MutableValue { var value: Boolean = false - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Boolean] + value = v.asInstanceOf[Boolean] } - def copy() = { + override def copy(): MutableBoolean = { val newCopy = new MutableBoolean newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableBoolean] } } final class MutableDouble extends MutableValue { var value: Double = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Double] + value = v.asInstanceOf[Double] } - def copy() = { + override def copy(): MutableDouble = { val newCopy = new MutableDouble newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableDouble] } } final class MutableShort extends MutableValue { var value: Short = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Short] } - def copy() = { + override def copy(): MutableShort = { val newCopy = new MutableShort newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableShort] } } final class MutableLong extends MutableValue { var value: Long = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Long] } - def copy() = { + override def copy(): MutableLong = { val newCopy = new MutableLong newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableLong] } } final class MutableByte extends MutableValue { var value: Byte = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Byte] } - def copy() = { + override def copy(): MutableByte = { val newCopy = new MutableByte newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableByte] } } final class MutableAny extends MutableValue { var value: Any = _ - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Any] + value = v.asInstanceOf[Any] } - def copy() = { + override def copy(): MutableAny = { val newCopy = new MutableAny newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableAny] } } @@ -220,22 +220,23 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def isNullAt(i: Int): Boolean = values(i).isNull override def copy(): Row = { - val newValues = new Array[MutableValue](values.length) + val newValues = new Array[Any](values.length) var i = 0 while (i < values.length) { - newValues(i) = values(i).copy() + newValues(i) = values(i).boxed i += 1 } - new SpecificMutableRow(newValues) + + new GenericRow(newValues) } override def update(ordinal: Int, value: Any): Unit = { if (value == null) setNullAt(ordinal) else values(ordinal).update(value) } - override def setString(ordinal: Int, value: String) = update(ordinal, value) + override def setString(ordinal: Int, value: String): Unit = update(ordinal, value) - override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String] override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala old mode 100755 new mode 100644 index 735b7488fdcbd..ca0351ed0ed0e --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -79,27 +79,29 @@ abstract class AggregateFunction /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression - override def nullable = base.nullable - override def dataType = base.dataType + override def nullable: Boolean = base.nullable + override def dataType: DataType = base.dataType def update(input: Row): Unit // Do we really need this? - override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + override def newInstance(): AggregateFunction = { + makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + } } case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"MIN($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"MIN($child)" override def asPartial: SplitEvaluation = { val partialMin = Alias(Min(child), "PartialMin")() SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) } - override def newInstance() = new MinFunction(child, this) + override def newInstance(): MinFunction = new MinFunction(child, this) } case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -121,16 +123,16 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"MAX($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) } - override def newInstance() = new MaxFunction(child, this) + override def newInstance(): MaxFunction = new MaxFunction(child, this) } case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -152,29 +154,29 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"COUNT($child)" + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) } - override def newInstance() = new CountFunction(child, this) + override def newInstance(): CountFunction = new CountFunction(child, this) } case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { def this() = this(null) - override def children = expressions + override def children: Seq[Expression] = expressions - override def nullable = false - override def dataType = LongType - override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})" - override def newInstance() = new CountDistinctFunction(expressions, this) + override def nullable: Boolean = false + override def dataType: DataType = LongType + override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})" + override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this) - override def asPartial = { + override def asPartial: SplitEvaluation = { val partialSet = Alias(CollectHashSet(expressions), "partialSets")() SplitEvaluation( CombineSetsAndCount(partialSet.toAttribute), @@ -185,11 +187,11 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { def this() = this(null) - override def children = expressions - override def nullable = false - override def dataType = ArrayType(expressions.head.dataType) - override def toString = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance() = new CollectHashSetFunction(expressions, this) + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: ArrayType = ArrayType(expressions.head.dataType) + override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" + override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this) } case class CollectHashSetFunction( @@ -219,11 +221,13 @@ case class CollectHashSetFunction( case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { def this() = this(null) - override def children = inputSet :: Nil - override def nullable = false - override def dataType = LongType - override def toString = s"CombineAndCount($inputSet)" - override def newInstance() = new CombineSetsAndCountFunction(inputSet, this) + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = false + override def dataType: DataType = LongType + override def toString: String = s"CombineAndCount($inputSet)" + override def newInstance(): CombineSetsAndCountFunction = { + new CombineSetsAndCountFunction(inputSet, this) + } } case class CombineSetsAndCountFunction( @@ -249,27 +253,31 @@ case class CombineSetsAndCountFunction( case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = child.dataType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD) + override def nullable: Boolean = false + override def dataType: DataType = child.dataType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" + override def newInstance(): ApproxCountDistinctPartitionFunction = { + new ApproxCountDistinctPartitionFunction(child, this, relativeSD) + } } case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD) + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" + override def newInstance(): ApproxCountDistinctMergeFunction = { + new ApproxCountDistinctMergeFunction(child, this, relativeSD) + } } case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" override def asPartial: SplitEvaluation = { val partialCount = @@ -280,14 +288,14 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) partialCount :: Nil) } - override def newInstance() = new CountDistinctFunction(child :: Nil, this) + override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) } case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true + override def nullable: Boolean = true - override def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive case DecimalType.Unlimited => @@ -296,7 +304,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString = s"AVG($child)" + override def toString: String = s"AVG($child)" override def asPartial: SplitEvaluation = { child.dataType match { @@ -323,14 +331,14 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } } - override def newInstance() = new AverageFunction(child, this) + override def newInstance(): AverageFunction = new AverageFunction(child, this) } case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true + override def nullable: Boolean = true - override def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive case DecimalType.Unlimited => @@ -339,7 +347,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } - override def toString = s"SUM($child)" + override def toString: String = s"SUM($child)" override def asPartial: SplitEvaluation = { child.dataType match { @@ -357,15 +365,15 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ } } - override def newInstance() = new SumFunction(child, this) + override def newInstance(): SumFunction = new SumFunction(child, this) } case class SumDistinct(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { def this() = this(null) - override def nullable = true - override def dataType = child.dataType match { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive case DecimalType.Unlimited => @@ -373,10 +381,10 @@ case class SumDistinct(child: Expression) case _ => child.dataType } - override def toString = s"SUM(DISTINCT ${child})" - override def newInstance() = new SumDistinctFunction(child, this) + override def toString: String = s"SUM(DISTINCT $child)" + override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - override def asPartial = { + override def asPartial: SplitEvaluation = { val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() SplitEvaluation( CombineSetsAndSum(partialSet.toAttribute, this), @@ -387,11 +395,13 @@ case class SumDistinct(child: Expression) case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { def this() = this(null, null) - override def children = inputSet :: Nil - override def nullable = true - override def dataType = base.dataType - override def toString = s"CombineAndSum($inputSet)" - override def newInstance() = new CombineSetsAndSumFunction(inputSet, this) + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = true + override def dataType: DataType = base.dataType + override def toString: String = s"CombineAndSum($inputSet)" + override def newInstance(): CombineSetsAndSumFunction = { + new CombineSetsAndSumFunction(inputSet, this) + } } case class CombineSetsAndSumFunction( @@ -425,9 +435,9 @@ case class CombineSetsAndSumFunction( } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"FIRST($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"FIRST($child)" override def asPartial: SplitEvaluation = { val partialFirst = Alias(First(child), "PartialFirst")() @@ -435,14 +445,14 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod First(partialFirst.toAttribute), partialFirst :: Nil) } - override def newInstance() = new FirstFunction(child, this) + override def newInstance(): FirstFunction = new FirstFunction(child, this) } case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references - override def nullable = true - override def dataType = child.dataType - override def toString = s"LAST($child)" + override def references: AttributeSet = child.references + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"LAST($child)" override def asPartial: SplitEvaluation = { val partialLast = Alias(Last(child), "PartialLast")() @@ -450,7 +460,7 @@ case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode Last(partialLast.toAttribute), partialLast :: Nil) } - override def newInstance() = new LastFunction(child, this) + override def newInstance(): LastFunction = new LastFunction(child, this) } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -651,6 +661,7 @@ case class LastFunction(expr: Expression, base: AggregateExpression) extends Agg result = input } - override def eval(input: Row): Any = if (result != null) expr.eval(result.asInstanceOf[Row]) - else null + override def eval(input: Row): Any = { + if (result != null) expr.eval(result.asInstanceOf[Row]) else null + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 574907f566c0f..1f6526ef66c56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,41 +18,53 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types._ case class UnaryMinus(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"-$child" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"-$child" + + lazy val numeric = dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } override def eval(input: Row): Any = { - n1(child, input, _.negate(_)) + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + numeric.negate(evalE) + } } } case class Sqrt(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = DoubleType - override def foldable = child.foldable - def nullable = true - override def toString = s"SQRT($child)" + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"SQRT($child)" + + lazy val numeric = child.dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support non-negative numeric operations") + } override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - child.dataType match { - case n: NumericType => - val value = n.numeric.toDouble(evalE.asInstanceOf[n.JvmType]) - if (value < 0) null - else math.sqrt(value) - case other => sys.error(s"Type $other does not support non-negative numeric operations") - } + val value = numeric.toDouble(evalE) + if (value < 0) null + else math.sqrt(value) } } } @@ -62,14 +74,14 @@ abstract class BinaryArithmetic extends BinaryExpression { type EvaluatedType = Any - def nullable = left.nullable || right.nullable + def nullable: Boolean = left.nullable || right.nullable override lazy val resolved = left.resolved && right.resolved && left.dataType == right.dataType && !DecimalType.isFixed(left.dataType) - def dataType = { + def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") @@ -96,51 +108,122 @@ abstract class BinaryArithmetic extends BinaryExpression { } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "+" + override def symbol: String = "+" + + lazy val numeric = dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } - override def eval(input: Row): Any = n2(input, left, right, _.plus(_, _)) + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if(evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + numeric.plus(evalE1, evalE2) + } + } + } } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "-" + override def symbol: String = "-" - override def eval(input: Row): Any = n2(input, left, right, _.minus(_, _)) + lazy val numeric = dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if(evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + numeric.minus(evalE1, evalE2) + } + } + } } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "*" + override def symbol: String = "*" + + lazy val numeric = dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } - override def eval(input: Row): Any = n2(input, left, right, _.times(_, _)) + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if(evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + numeric.times(evalE1, evalE2) + } + } + } } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "/" + override def symbol: String = "/" - override def nullable = true + override def nullable: Boolean = true + lazy val div: (Any, Any) => Any = dataType match { + case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div + case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot + case other => sys.error(s"Type $other does not support numeric operations") + } + override def eval(input: Row): Any = { val evalE2 = right.eval(input) - dataType match { - case _ if evalE2 == null => null - case _ if evalE2 == 0 => null - case ft: FractionalType => f1(input, left, _.div(_, evalE2.asInstanceOf[ft.JvmType])) - case it: IntegralType => i1(input, left, _.quot(_, evalE2.asInstanceOf[it.JvmType])) + if (evalE2 == null || evalE2 == 0) { + null + } else { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + div(evalE1, evalE2) + } } } - } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "%" + override def symbol: String = "%" + + override def nullable: Boolean = true - override def nullable = true + lazy val integral = dataType match { + case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] + case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } override def eval(input: Row): Any = { val evalE2 = right.eval(input) - dataType match { - case _ if evalE2 == null => null - case _ if evalE2 == 0 => null - case nt: NumericType => i1(input, left, _.rem(_, evalE2.asInstanceOf[nt.JvmType])) + if (evalE2 == null || evalE2 == 0) { + null + } else { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + integral.rem(evalE1, evalE2) + } } } } @@ -149,45 +232,63 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet * A function that calculates bitwise and(&) of two numbers. */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "&" - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match { - case ByteType => (evalE1.asInstanceOf[Byte] & evalE2.asInstanceOf[Byte]).toByte - case ShortType => (evalE1.asInstanceOf[Short] & evalE2.asInstanceOf[Short]).toShort - case IntegerType => evalE1.asInstanceOf[Int] & evalE2.asInstanceOf[Int] - case LongType => evalE1.asInstanceOf[Long] & evalE2.asInstanceOf[Long] + override def symbol: String = "&" + + lazy val and: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] case other => sys.error(s"Unsupported bitwise & operation on $other") } + + override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = and(evalE1, evalE2) } /** * A function that calculates bitwise or(|) of two numbers. */ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "|" - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match { - case ByteType => (evalE1.asInstanceOf[Byte] | evalE2.asInstanceOf[Byte]).toByte - case ShortType => (evalE1.asInstanceOf[Short] | evalE2.asInstanceOf[Short]).toShort - case IntegerType => evalE1.asInstanceOf[Int] | evalE2.asInstanceOf[Int] - case LongType => evalE1.asInstanceOf[Long] | evalE2.asInstanceOf[Long] + override def symbol: String = "|" + + lazy val or: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] case other => sys.error(s"Unsupported bitwise | operation on $other") } + + override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = or(evalE1, evalE2) } /** * A function that calculates bitwise xor(^) of two numbers. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "^" - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match { - case ByteType => (evalE1.asInstanceOf[Byte] ^ evalE2.asInstanceOf[Byte]).toByte - case ShortType => (evalE1.asInstanceOf[Short] ^ evalE2.asInstanceOf[Short]).toShort - case IntegerType => evalE1.asInstanceOf[Int] ^ evalE2.asInstanceOf[Int] - case LongType => evalE1.asInstanceOf[Long] ^ evalE2.asInstanceOf[Long] + override def symbol: String = "^" + + lazy val xor: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] case other => sys.error(s"Unsupported bitwise ^ operation on $other") } + + override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = xor(evalE1, evalE2) } /** @@ -196,23 +297,29 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme case class BitwiseNot(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"~$child" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"~$child" + + lazy val not: (Any) => Any = dataType match { + case ByteType => + ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] + case ShortType => + ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any] + case IntegerType => + ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] + case LongType => + ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] + case other => sys.error(s"Unsupported bitwise ~ operation on $other") + } override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - dataType match { - case ByteType => (~evalE.asInstanceOf[Byte]).toByte - case ShortType => (~evalE.asInstanceOf[Short]).toShort - case IntegerType => ~evalE.asInstanceOf[Int] - case LongType => ~evalE.asInstanceOf[Long] - case other => sys.error(s"Unsupported bitwise ~ operation on $other") - } + not(evalE) } } } @@ -220,32 +327,46 @@ case class BitwiseNot(child: Expression) extends UnaryExpression { case class MaxOf(left: Expression, right: Expression) extends Expression { type EvaluatedType = Any - override def foldable = left.foldable && right.foldable + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable && right.nullable - override def nullable = left.nullable && right.nullable + override def children: Seq[Expression] = left :: right :: Nil + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType - override def children = left :: right :: Nil + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } - override def dataType = left.dataType + lazy val ordering = left.dataType match { + case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } override def eval(input: Row): Any = { - val leftEval = left.eval(input) - val rightEval = right.eval(input) - if (leftEval == null) { - rightEval - } else if (rightEval == null) { - leftEval + val evalE1 = left.eval(input) + val evalE2 = right.eval(input) + if (evalE1 == null) { + evalE2 + } else if (evalE2 == null) { + evalE1 } else { - val numeric = left.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] - if (numeric.compare(leftEval, rightEval) < 0) { - rightEval + if (ordering.compare(evalE1, evalE2) < 0) { + evalE2 } else { - leftEval + evalE1 } } } - override def toString = s"MaxOf($left, $right)" + override def toString: String = s"MaxOf($left, $right)" } /** @@ -254,10 +375,22 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { case class Abs(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"Abs($child)" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"Abs($child)" - override def eval(input: Row): Any = n1(child, input, _.abs(_)) + lazy val numeric = dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + numeric.abs(evalE) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4cae5c4718683..d1abf3c0b64a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -91,7 +91,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val startTime = System.nanoTime() val result = create(in) val endTime = System.nanoTime() - def timeMs = (endTime - startTime).toDouble / 1000000 + def timeMs: Double = (endTime - startTime).toDouble / 1000000 logInfo(s"Code generated expression $in in $timeMs ms") result } @@ -121,7 +121,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * @param nullTerm A term that holds a boolean value representing whether the expression evaluated * to null. * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not - * valid if `nullTerm` is set to `false`. + * valid if `nullTerm` is set to `true`. * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ protected case class EvaluatedExpression( @@ -246,6 +246,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) """.children + case Cast(child @ DateType(), StringType) => + child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType) + case Cast(child @ NumericType(), IntegerType) => child.castOrNull(c => q"$c.toInt", IntegerType) @@ -256,7 +259,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin child.castOrNull(c => q"$c.toDouble", DoubleType) case Cast(child @ NumericType(), FloatType) => - child.castOrNull(c => q"$c.toFloat", IntegerType) + child.castOrNull(c => q"$c.toFloat", FloatType) // Special handling required for timestamps in hive test cases since the toString function // does not match the expected output. @@ -623,7 +626,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case FloatType => ru.Literal(Constant(-1.0.toFloat)) case StringType => ru.Literal(Constant("")) case ShortType => ru.Literal(Constant(-1.toShort)) - case LongType => ru.Literal(Constant(1L)) + case LongType => ru.Literal(Constant(-1L)) case ByteType => ru.Literal(Constant(-1.toByte)) case DoubleType => ru.Literal(Constant(-1.toDouble)) case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 1bc34f71441fe..3fd78db297462 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -27,12 +27,12 @@ import org.apache.spark.sql.types._ case class GetItem(child: Expression, ordinal: Expression) extends Expression { type EvaluatedType = Any - val children = child :: ordinal :: Nil + val children: Seq[Expression] = child :: ordinal :: Nil /** `Null` is returned for invalid ordinals. */ - override def nullable = true - override def foldable = child.foldable && ordinal.foldable + override def nullable: Boolean = true + override def foldable: Boolean = child.foldable && ordinal.foldable - def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case MapType(_, vt, _) => vt } @@ -40,7 +40,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { childrenResolved && (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) - override def toString = s"$child[$ordinal]" + override def toString: String = s"$child[$ordinal]" override def eval(input: Row): Any = { val value = child.eval(input) @@ -70,42 +70,48 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { } } -/** - * Returns the value of fields in the Struct `child`. - */ -case class GetField(child: Expression, fieldName: String) extends UnaryExpression { - type EvaluatedType = Any - def dataType = field.dataType - override def nullable = child.nullable || field.nullable - override def foldable = child.foldable +trait GetField extends UnaryExpression { + self: Product => - protected def structType = child.dataType match { - case s: StructType => s - case otherType => sys.error(s"GetField is not valid on fields of type $otherType") - } + type EvaluatedType = Any + override def foldable: Boolean = child.foldable + override def toString: String = s"$child.${field.name}" - lazy val field = - structType.fields - .find(_.name == fieldName) - .getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}")) + def field: StructField +} - lazy val ordinal = structType.fields.indexOf(field) +/** + * Returns the value of fields in the Struct `child`. + */ +case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField { - override lazy val resolved = childrenResolved && fieldResolved - - /** Returns true only if the fieldName is found in the child struct. */ - private def fieldResolved = child.dataType match { - case StructType(fields) => fields.map(_.name).contains(fieldName) - case _ => false - } + override def dataType: DataType = field.dataType + override def nullable: Boolean = child.nullable || field.nullable override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Row] if (baseValue == null) null else baseValue(ordinal) } +} + +/** + * Returns the array of value of fields in the Array of Struct `child`. + */ +case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean) + extends GetField { - override def toString = s"$child.$fieldName" + override def dataType: DataType = ArrayType(field.dataType, containsNull) + override def nullable: Boolean = child.nullable + + override def eval(input: Row): Any = { + val baseValue = child.eval(input).asInstanceOf[Seq[Row]] + if (baseValue == null) null else { + baseValue.map { row => + if (row == null) null else row(ordinal) + } + } + } } /** @@ -114,7 +120,7 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio case class CreateArray(children: Seq[Expression]) extends Expression { override type EvaluatedType = Any - override def foldable = !children.exists(!_.foldable) + override def foldable: Boolean = !children.exists(!_.foldable) lazy val childTypes = children.map(_.dataType).distinct @@ -134,5 +140,5 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def toString = s"Array(${children.mkString(",")})" + override def toString: String = s"Array(${children.mkString(",")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 83d8c1d42bca4..adb94df7d1c7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -24,9 +24,9 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { override type EvaluatedType = Any override def dataType: DataType = LongType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"UnscaledValue($child)" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"UnscaledValue($child)" override def eval(input: Row): Any = { val childResult = child.eval(input) @@ -43,9 +43,9 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override type EvaluatedType = Decimal override def dataType: DataType = DecimalType(precision, scale) - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"MakeDecimal($child,$precision,$scale)" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"MakeDecimal($child,$precision,$scale)" override def eval(input: Row): Decimal = { val childResult = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 43b6482c0171c..860b72fad38b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -45,7 +45,7 @@ abstract class Generator extends Expression { override lazy val dataType = ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) - override def nullable = false + override def nullable: Boolean = false /** * Should be overridden by specific generators. Called only once for each instance to ensure @@ -73,6 +73,25 @@ abstract class Generator extends Expression { } } +/** + * A generator that produces its output using the provided lambda function. + */ +case class UserDefinedGenerator( + schema: Seq[Attribute], + function: Row => TraversableOnce[Row], + children: Seq[Expression]) + extends Generator{ + + override protected def makeOutput(): Seq[Attribute] = schema + + override def eval(input: Row): TraversableOnce[Row] = { + val inputRow = new InterpretedProjection(children) + function(inputRow(input)) + } + + override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" +} + /** * Given an input array produces a sequence of rows for each value in the array. */ @@ -111,5 +130,5 @@ case class Explode(attributeNames: Seq[String], child: Expression) } } - override def toString() = s"explode($child)" + override def toString: String = s"explode($child)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 5b389aad7a85d..19f3fc9c2291a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -35,9 +35,11 @@ object Literal { case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: Decimal => Literal(d, DecimalType.Unlimited) case t: Timestamp => Literal(t, TimestampType) - case d: Date => Literal(d, DateType) + case d: Date => Literal(DateUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) + case _ => + throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } } @@ -62,14 +64,13 @@ object IntegerLiteral { case class Literal(value: Any, dataType: DataType) extends LeafExpression { - override def foldable = true - def nullable = value == null + override def foldable: Boolean = true + override def nullable: Boolean = value == null - - override def toString = if (value != null) value.toString else "null" + override def toString: String = if (value != null) value.toString else "null" type EvaluatedType = Any - override def eval(input: Row):Any = value + override def eval(input: Row): Any = value } // TODO: Specialize @@ -77,9 +78,9 @@ case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean extends LeafExpression { type EvaluatedType = Any - def update(expression: Expression, input: Row) = { + def update(expression: Expression, input: Row): Unit = { value = expression.eval(input) } - override def eval(input: Row) = value + override def eval(input: Row): Any = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index f388cd5972bac..f70753127d43a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.trees.LeafNode import org.apache.spark.sql.types._ object NamedExpression { private val curId = new java.util.concurrent.atomic.AtomicLong() - def newExprId = ExprId(curId.getAndIncrement()) + def newExprId: ExprId = ExprId(curId.getAndIncrement()) def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType) } @@ -40,6 +41,24 @@ abstract class NamedExpression extends Expression { def name: String def exprId: ExprId + + /** + * Returns a dot separated fully qualified name for this attribute. Given that there can be + * multiple qualifiers, it is possible that there are other possible way to refer to this + * attribute. + */ + def qualifiedName: String = (qualifiers.headOption.toSeq :+ name).mkString(".") + + /** + * All possible qualifiers for the expression. + * + * For now, since we do not allow using original table name to qualify a column name once the + * table is aliased, this can only be: + * + * 1. Empty Seq: when an attribute doesn't have a qualifier, + * e.g. top level attributes aliased in the SELECT clause, or column from a LocalRelation. + * 2. Single element: either the table name or the alias name of the table. + */ def qualifiers: Seq[String] def toAttribute: Attribute @@ -61,13 +80,13 @@ abstract class NamedExpression extends Expression { abstract class Attribute extends NamedExpression { self: Product => - override def references = AttributeSet(this) + override def references: AttributeSet = AttributeSet(this) def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute def withName(newName: String): Attribute - def toAttribute = this + def toAttribute: Attribute = this def newInstance(): Attribute } @@ -75,7 +94,7 @@ abstract class Attribute extends NamedExpression { /** * Used to assign a new name to a computation. * For example the SQL expression "1 + 1 AS a" could be represented as follows: - * Alias(Add(Literal(1), Literal(1), "a")() + * Alias(Add(Literal(1), Literal(1)), "a")() * * Note that exprId and qualifiers are in a separate parameter list because * we only pattern match on child and name. @@ -91,10 +110,10 @@ case class Alias(child: Expression, name: String) override type EvaluatedType = Any - override def eval(input: Row) = child.eval(input) + override def eval(input: Row): Any = child.eval(input) - override def dataType = child.dataType - override def nullable = child.nullable + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable override def metadata: Metadata = { child match { case named: NamedExpression => named.metadata @@ -102,7 +121,7 @@ case class Alias(child: Expression, name: String) } } - override def toAttribute = { + override def toAttribute: Attribute = { if (resolved) { AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) } else { @@ -112,7 +131,13 @@ case class Alias(child: Expression, name: String) override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" - override protected final def otherCopyArgs = exprId :: qualifiers :: Nil + override protected final def otherCopyArgs: Seq[AnyRef] = exprId :: qualifiers :: Nil + + override def equals(other: Any): Boolean = other match { + case a: Alias => + name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers + case _ => false + } } /** @@ -136,7 +161,7 @@ case class AttributeReference( val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType case _ => false } @@ -150,7 +175,7 @@ case class AttributeReference( h } - override def newInstance() = + override def newInstance(): AttributeReference = AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers) /** @@ -175,7 +200,7 @@ case class AttributeReference( /** * Returns a copy of this [[AttributeReference]] with new qualifiers. */ - override def withQualifiers(newQualifiers: Seq[String]) = { + override def withQualifiers(newQualifiers: Seq[String]): AttributeReference = { if (newQualifiers.toSet == qualifiers.toSet) { this } else { @@ -190,7 +215,29 @@ case class AttributeReference( override def toString: String = s"$name#${exprId.id}$typeSuffix" } +/** + * A place holder used when printing expressions without debugging information such as the + * expression id or the unresolved indicator. + */ +case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { + type EvaluatedType = Any + + override def toString: String = name + + override def withNullability(newNullability: Boolean): Attribute = + throw new UnsupportedOperationException + override def newInstance(): Attribute = throw new UnsupportedOperationException + override def withQualifiers(newQualifiers: Seq[String]): Attribute = + throw new UnsupportedOperationException + override def withName(newName: String): Attribute = throw new UnsupportedOperationException + override def qualifiers: Seq[String] = throw new UnsupportedOperationException + override def exprId: ExprId = throw new UnsupportedOperationException + override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def nullable: Boolean = throw new UnsupportedOperationException + override def dataType: DataType = NullType +} + object VirtualColumn { - val groupingIdName = "grouping__id" - def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)() + val groupingIdName: String = "grouping__id" + def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 08b982bc671e7..f9161cf34f0c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -19,22 +19,23 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { type EvaluatedType = Any /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ - def nullable = !children.exists(!_.nullable) + override def nullable: Boolean = !children.exists(!_.nullable) // Coalesce is foldable if all children are foldable. - override def foldable = !children.exists(!_.foldable) + override def foldable: Boolean = !children.exists(!_.foldable) // Only resolved if all the children are of the same type. override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1) - override def toString = s"Coalesce(${children.mkString(",")})" + override def toString: String = s"Coalesce(${children.mkString(",")})" - def dataType = if (resolved) { + override def dataType: DataType = if (resolved) { children.head.dataType } else { val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") @@ -54,22 +55,45 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable = child.foldable - def nullable = false + override def foldable: Boolean = child.foldable + override def nullable: Boolean = false override def eval(input: Row): Any = { child.eval(input) == null } - override def toString = s"IS NULL $child" + override def toString: String = s"IS NULL $child" } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable = child.foldable - def nullable = false - override def toString = s"IS NOT NULL $child" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = false + override def toString: String = s"IS NOT NULL $child" override def eval(input: Row): Any = { child.eval(input) != null } } + +/** + * A predicate that is evaluated to be true if there are at least `n` non-null values. + */ +case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { + override def nullable: Boolean = false + override def foldable: Boolean = false + override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" + + private[this] val childrenArray = children.toArray + + override def eval(input: Row): Boolean = { + var numNonNulls = 0 + var i = 0 + while (i < childrenArray.length && numNonNulls < n) { + if (childrenArray(i).eval(input) != null) { + numNonNulls += 1 + } + i += 1 + } + numNonNulls >= n + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c84cc95520a19..7e47cb3fffe12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, NativeType} object InterpretedPredicate { def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -33,7 +34,7 @@ object InterpretedPredicate { trait Predicate extends Expression { self: Product => - def dataType = BooleanType + override def dataType: DataType = BooleanType type EvaluatedType = Any } @@ -71,13 +72,13 @@ trait PredicateHelper { abstract class BinaryPredicate extends BinaryExpression with Predicate { self: Product => - def nullable = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable } case class Not(child: Expression) extends UnaryExpression with Predicate { - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"NOT $child" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"NOT $child" override def eval(input: Row): Any = { child.eval(input) match { @@ -91,10 +92,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate { * Evaluates to `true` if `list` contains `value`. */ case class In(value: Expression, list: Seq[Expression]) extends Predicate { - def children = value +: list + override def children: Seq[Expression] = value +: list - def nullable = true // TODO: Figure out correct nullability semantics of IN. - override def toString = s"$value IN ${list.mkString("(", ",", ")")}" + override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: Row): Any = { val evaluatedValue = value.eval(input) @@ -109,10 +110,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { case class InSet(value: Expression, hset: Set[Any]) extends Predicate { - def children = value :: Nil + override def children: Seq[Expression] = value :: Nil - def nullable = true // TODO: Figure out correct nullability semantics of IN. - override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}" + override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}" override def eval(input: Row): Any = { hset.contains(value.eval(input)) @@ -120,7 +121,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) extends BinaryPredicate { - def symbol = "&&" + override def symbol: String = "&&" override def eval(input: Row): Any = { val l = left.eval(input) @@ -142,7 +143,7 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate { } case class Or(left: Expression, right: Expression) extends BinaryPredicate { - def symbol = "||" + override def symbol: String = "||" override def eval(input: Row): Any = { val l = left.eval(input) @@ -168,21 +169,27 @@ abstract class BinaryComparison extends BinaryPredicate { } case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "=" + override def symbol: String = "=" + override def eval(input: Row): Any = { val l = left.eval(input) if (l == null) { null } else { val r = right.eval(input) - if (r == null) null else l == r + if (r == null) null + else if (left.dataType != BinaryType) l == r + else BinaryType.ordering.compare( + l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0 } } } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<=>" - override def nullable = false + override def symbol: String = "<=>" + + override def nullable: Boolean = false + override def eval(input: Row): Any = { val l = left.eval(input) val r = right.eval(input) @@ -197,33 +204,129 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<" - override def eval(input: Row): Any = c2(input, left, right, _.lt(_, _)) + override def symbol: String = "<" + + lazy val ordering: Ordering[Any] = { + if (left.dataType != right.dataType) { + throw new TreeNodeException(this, + s"Types do not match ${left.dataType} != ${right.dataType}") + } + left.dataType match { + case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + ordering.lt(evalE1, evalE2) + } + } + } } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<=" - override def eval(input: Row): Any = c2(input, left, right, _.lteq(_, _)) + override def symbol: String = "<=" + + lazy val ordering: Ordering[Any] = { + if (left.dataType != right.dataType) { + throw new TreeNodeException(this, + s"Types do not match ${left.dataType} != ${right.dataType}") + } + left.dataType match { + case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + ordering.lteq(evalE1, evalE2) + } + } + } } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { - def symbol = ">" - override def eval(input: Row): Any = c2(input, left, right, _.gt(_, _)) + override def symbol: String = ">" + + lazy val ordering: Ordering[Any] = { + if (left.dataType != right.dataType) { + throw new TreeNodeException(this, + s"Types do not match ${left.dataType} != ${right.dataType}") + } + left.dataType match { + case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if(evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + ordering.gt(evalE1, evalE2) + } + } + } } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - def symbol = ">=" - override def eval(input: Row): Any = c2(input, left, right, _.gteq(_, _)) + override def symbol: String = ">=" + + lazy val ordering: Ordering[Any] = { + if (left.dataType != right.dataType) { + throw new TreeNodeException(this, + s"Types do not match ${left.dataType} != ${right.dataType}") + } + left.dataType match { + case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + ordering.gteq(evalE1, evalE2) + } + } + } } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { + extends Expression { - def children = predicate :: trueValue :: falseValue :: Nil - override def nullable = trueValue.nullable || falseValue.nullable + override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil + override def nullable: Boolean = trueValue.nullable || falseValue.nullable override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType - def dataType = { + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException( this, @@ -242,7 +345,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def toString = s"if ($predicate) $trueValue else $falseValue" + override def toString: String = s"if ($predicate) $trueValue else $falseValue" } // scalastyle:off @@ -262,9 +365,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi // scalastyle:on case class CaseWhen(branches: Seq[Expression]) extends Expression { type EvaluatedType = Any - def children = branches - def dataType = { + override def children: Seq[Expression] = branches + + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") } @@ -279,12 +383,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { @transient private[this] lazy val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) - override def nullable = { + override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) } - override lazy val resolved = { + override lazy val resolved: Boolean = { if (!childrenResolved) { false } else { @@ -315,7 +419,7 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { res } - override def toString = { + override def toString: String = { "CASE" + branches.sliding(2, 2).map { case Seq(cond, value) => s" WHEN $cond THEN $value" case Seq(elseValue) => s" ELSE $elseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 73ec7a6d114f5..a8983df208318 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.{StructType, NativeType} /** @@ -44,8 +44,8 @@ trait MutableRow extends Row { */ object EmptyRow extends Row { override def apply(i: Int): Any = throw new UnsupportedOperationException - override def toSeq = Seq.empty - override def length = 0 + override def toSeq: Seq[Any] = Seq.empty + override def length: Int = 0 override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException override def getInt(i: Int): Int = throw new UnsupportedOperationException override def getLong(i: Int): Long = throw new UnsupportedOperationException @@ -56,7 +56,7 @@ object EmptyRow extends Row { override def getByte(i: Int): Byte = throw new UnsupportedOperationException override def getString(i: Int): String = throw new UnsupportedOperationException override def getAs[T](i: Int): T = throw new UnsupportedOperationException - def copy() = this + override def copy(): Row = this } /** @@ -66,17 +66,17 @@ object EmptyRow extends Row { */ class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) - override def toSeq = values.toSeq + override def toSeq: Seq[Any] = values.toSeq - override def length = values.length + override def length: Int = values.length - override def apply(i: Int) = values(i) + override def apply(i: Int): Any = values(i) - override def isNullAt(i: Int) = values(i) == null + override def isNullAt(i: Int): Boolean = values(i) == null override def getInt(i: Int): Int = { if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") @@ -146,12 +146,40 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { result } - def copy() = this + override def equals(o: Any): Boolean = o match { + case other: Row => + if (values.length != other.length) { + return false + } + + var i = 0 + while (i < values.length) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (apply(i) != other.apply(i)) { + return false + } + i += 1 + } + true + + case _ => false + } + + override def copy(): Row = this +} + +class GenericRowWithSchema(values: Array[Any], override val schema: StructType) + extends GenericRow(values) { + + /** No-arg constructor for serialization. */ + protected def this() = this(null, null) } class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { /** No-arg constructor for serialization. */ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) @@ -169,7 +197,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } - override def copy() = new GenericRow(values.clone()) + override def copy(): Row = new GenericRow(values.clone()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 3a5bdca1f07c3..35faa00782e80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -26,17 +26,17 @@ import org.apache.spark.util.collection.OpenHashSet case class NewSet(elementType: DataType) extends LeafExpression { type EvaluatedType = Any - def nullable = false + override def nullable: Boolean = false // We are currently only using these Expressions internally for aggregation. However, if we ever // expose these to users we'll want to create a proper type instead of hijacking ArrayType. - def dataType = ArrayType(elementType) + override def dataType: DataType = ArrayType(elementType) - def eval(input: Row): Any = { + override def eval(input: Row): Any = { new OpenHashSet[Any]() } - override def toString = s"new Set($dataType)" + override def toString: String = s"new Set($dataType)" } /** @@ -46,12 +46,13 @@ case class NewSet(elementType: DataType) extends LeafExpression { case class AddItemToSet(item: Expression, set: Expression) extends Expression { type EvaluatedType = Any - def children = item :: set :: Nil + override def children: Seq[Expression] = item :: set :: Nil - def nullable = set.nullable + override def nullable: Boolean = set.nullable - def dataType = set.dataType - def eval(input: Row): Any = { + override def dataType: DataType = set.dataType + + override def eval(input: Row): Any = { val itemEval = item.eval(input) val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -67,7 +68,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } - override def toString = s"$set += $item" + override def toString: String = s"$set += $item" } /** @@ -77,13 +78,13 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { type EvaluatedType = Any - def nullable = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable - def dataType = left.dataType + override def dataType: DataType = left.dataType - def symbol = "++=" + override def symbol: String = "++=" - def eval(input: Row): Any = { + override def eval(input: Row): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -109,16 +110,16 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres case class CountSet(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def nullable = child.nullable + override def nullable: Boolean = child.nullable - def dataType = LongType + override def dataType: DataType = LongType - def eval(input: Row): Any = { + override def eval(input: Row): Any = { val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]] if (childEval != null) { childEval.size.toLong } } - override def toString = s"$child.count()" + override def toString: String = s"$child.count()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index f85ee0a9bb6d8..3cdca4e9dd2d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -33,8 +33,8 @@ trait StringRegexExpression { def escape(v: String): String def matches(regex: Pattern, str: String): Boolean - def nullable: Boolean = left.nullable || right.nullable - def dataType: DataType = BooleanType + override def nullable: Boolean = left.nullable || right.nullable + override def dataType: DataType = BooleanType // try cache the pattern for Literal private lazy val cache: Pattern = right match { @@ -98,11 +98,11 @@ trait CaseConversionExpression { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - def symbol = "LIKE" + override def symbol: String = "LIKE" // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character - override def escape(v: String) = + override def escape(v: String): String = if (!v.isEmpty) { "(?s)" + (' ' +: v.init).zip(v).flatMap { case (prev, '\\') => "" @@ -129,7 +129,7 @@ case class Like(left: Expression, right: Expression) case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - def symbol = "RLIKE" + override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } @@ -141,7 +141,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: String): String = v.toUpperCase() - override def toString() = s"Upper($child)" + override def toString: String = s"Upper($child)" } /** @@ -151,7 +151,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: String): String = v.toLowerCase() - override def toString() = s"Lower($child)" + override def toString: String = s"Lower($child)" } /** A base trait for functions that compare two strings, returning a boolean. */ @@ -160,7 +160,7 @@ trait StringComparison { type EvaluatedType = Any - def nullable: Boolean = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType def compare(l: String, r: String): Boolean @@ -175,9 +175,9 @@ trait StringComparison { } } - def symbol: String = nodeName + override def symbol: String = nodeName - override def toString() = s"$nodeName($left, $right)" + override def toString: String = s"$nodeName($left, $right)" } /** @@ -185,7 +185,7 @@ trait StringComparison { */ case class Contains(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - override def compare(l: String, r: String) = l.contains(r) + override def compare(l: String, r: String): Boolean = l.contains(r) } /** @@ -193,7 +193,7 @@ case class Contains(left: Expression, right: Expression) */ case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - def compare(l: String, r: String) = l.startsWith(r) + override def compare(l: String, r: String): Boolean = l.startsWith(r) } /** @@ -201,7 +201,7 @@ case class StartsWith(left: Expression, right: Expression) */ case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - def compare(l: String, r: String) = l.endsWith(r) + override def compare(l: String, r: String): Boolean = l.endsWith(r) } /** @@ -212,17 +212,17 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends type EvaluatedType = Any - override def foldable = str.foldable && pos.foldable && len.foldable + override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - def nullable: Boolean = str.nullable || pos.nullable || len.nullable - def dataType: DataType = { + override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") } if (str.dataType == BinaryType) str.dataType else StringType } - override def children = str :: pos :: len :: Nil + override def children: Seq[Expression] = str :: pos :: len :: Nil @inline def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) @@ -267,7 +267,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends } } - override def toString = len match { + override def toString: String = len match { + // TODO: This is broken because max is not an integer value. case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" case _ => s"SUBSTR($str, $pos, $len)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 376a9f36568a7..c23d3b61887c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -32,6 +33,9 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] object DefaultOptimizer extends Optimizer { val batches = + // SubQueries are only needed for analysis and can be removed before execution. + Batch("Remove SubQueries", FixedPoint(100), + EliminateSubQueries) :: Batch("Combine Limits", FixedPoint(100), CombineLimits) :: Batch("ConstantFolding", FixedPoint(100), @@ -50,7 +54,10 @@ object DefaultOptimizer extends Optimizer { CombineFilters, PushPredicateThroughProject, PushPredicateThroughJoin, - ColumnPruning) :: Nil + PushPredicateThroughGenerate, + ColumnPruning) :: + Batch("LocalRelation", FixedPoint(100), + ConvertToLocalRelation) :: Nil } /** @@ -116,6 +123,15 @@ object ColumnPruning extends Rule[LogicalPlan] { case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = Project(a.references.toSeq, child)) + case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) + if (a.outputSet -- p.references).nonEmpty => + Project( + projectList, + Aggregate( + groupingExpressions, + aggregateExpressions.filter(e => p.references.contains(e)), + child)) + // Eliminate unneeded attributes from either side of a Join. case Project(projectList, Join(left, right, joinType, condition)) => // Collect the list of all references required either above or to evaluate the condition. @@ -125,7 +141,7 @@ object ColumnPruning extends Rule[LogicalPlan] { condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) /** Applies a projection only when the child is producing unnecessary attributes */ - def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences) + def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) @@ -206,7 +222,8 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) - case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType) + case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType) + case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) @@ -453,6 +470,30 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { } } +/** + * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference + * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath. + */ +object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case filter @ Filter(condition, + generate @ Generate(generator, join, outer, alias, grandChild)) => + // Predicates that reference attributes produced by the `Generate` operator cannot + // be pushed below the operator. + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { + conjunct => conjunct.references subsetOf grandChild.outputSet + } + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild)) + stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) + } else { + filter + } + } +} + /** * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other @@ -610,3 +651,17 @@ object DecimalAggregates extends Rule[LogicalPlan] { DecimalType(prec + 4, scale + 4)) } } + +/** + * Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to + * another LocalRelation. + * + * This is relatively simple as it currently handles only a single case: Project. + */ +object ConvertToLocalRelation extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Project(projectList, LocalRelation(output, data)) => + val projection = new InterpretedProjection(projectList, output) + LocalRelation(projectList.map(_.toAttribute), data.map(projection)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index b4c445b3badf1..9c8c643f7d17a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -91,16 +91,18 @@ object PhysicalOperation extends PredicateHelper { (None, Nil, other, Map.empty) } - def collectAliases(fields: Seq[Expression]) = fields.collect { + def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child }.toMap - def substitute(aliases: Map[Attribute, Expression])(expr: Expression) = expr.transform { - case a @ Alias(ref: AttributeReference, name) => - aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) + def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { + expr.transform { + case a @ Alias(ref: AttributeReference, name) => + aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) - case a: AttributeReference => - aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + case a: AttributeReference => + aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 619f42859cbb8..02f7c26a8ab6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.expressions.{VirtualColumn, Attribute, AttributeSet, Expression} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} @@ -47,8 +47,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy * Attributes that are referenced by expressions but not provided by this nodes children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. + * + * Note that virtual columns should be excluded. Currently, we only support the grouping ID + * virtual column. */ - def missingInput: AttributeSet = references -- inputSet + def missingInput: AttributeSet = + (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName) /** * Runs [[transform]] with `rule` on all expressions present in this query operator. @@ -67,7 +71,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { var changed = false - @inline def transformExpressionDown(e: Expression) = { + @inline def transformExpressionDown(e: Expression): Expression = { val newE = e.transformDown(rule) if (newE.fastEquals(e)) { e @@ -81,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionDown(e) case other => other @@ -99,7 +104,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { var changed = false - @inline def transformExpressionUp(e: Expression) = { + @inline def transformExpressionUp(e: Expression): Expression = { val newE = e.transformUp(rule) if (newE.fastEquals(e)) { e @@ -113,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionUp(e) case other => other @@ -152,7 +158,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** Prints out the schema in the tree format */ def printSchema(): Unit = println(schemaString) + /** + * A prefix string used when printing the plan. + * + * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan. + */ protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" - override def simpleString = statePrefix + super.simpleString + override def simpleString: String = statePrefix + super.simpleString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala similarity index 69% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index d90af45b375e4..bb79dc340553b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -17,31 +17,34 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types.{StructType, StructField} +import org.apache.spark.sql.types.{DataTypeConversions, StructType, StructField} object LocalRelation { def apply(output: Attribute*): LocalRelation = new LocalRelation(output) - def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation( - StructType(output1 +: output).toAttributes - ) + def apply(output1: StructField, output: StructField*): LocalRelation = { + new LocalRelation(StructType(output1 +: output).toAttributes) + } + + def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { + val schema = StructType.fromAttributes(output) + LocalRelation(output, data.map(row => DataTypeConversions.productToRow(row, schema))) + } } -case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) +case class LocalRelation(output: Seq[Attribute], data: Seq[Row] = Nil) extends LeafNode with analysis.MultiInstanceRelation { - // TODO: Validate schema compliance. - def loadData(newData: Seq[Product]) = new LocalRelation(output, data ++ newData) - /** * Returns an identical copy of this relation with new exprIds for all attributes. Different * attributes are required when a relation is going to be included multiple times in the same * query. */ - override final def newInstance: this.type = { - LocalRelation(output.map(_.newInstance), data).asInstanceOf[this.type] + override final def newInstance(): this.type = { + LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type] } override protected def stringArgs = Iterator(output) @@ -51,4 +54,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data case _ => false } + + override lazy val statistics = + Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 65ae066e4b4b5..2e9f3aa4ec4ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -18,38 +18,30 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types.{ArrayType, StructType, StructField} -/** - * Estimates of various statistics. The default estimation logic simply lazily multiplies the - * corresponding statistic produced by the children. To override this behavior, override - * `statistics` and assign it an overriden version of `Statistics`. - * - * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the - * performance of the implementations. The reason is that estimations might get triggered in - * performance-critical processes, such as query plan planning. - * - * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it - * defaults to the product of children's `sizeInBytes`. - */ -private[sql] case class Statistics(sizeInBytes: BigInt) abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { self: Product => + /** + * Computes [[Statistics]] for this plan. The default implementation assumes the output + * cardinality is the product of of all child plan's cardinality, i.e. applies in the case + * of cartesian joins. + * + * [[LeafNode]]s must override this. + */ def statistics: Statistics = { if (children.size == 0) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } - - Statistics( - sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product) + Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product) } /** @@ -82,12 +74,16 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * can do better should override this function. */ def sameResult(plan: LogicalPlan): Boolean = { - plan.getClass == this.getClass && - plan.children.size == children.size && { - logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]") - cleanArgs == plan.cleanArgs + val cleanLeft = EliminateSubQueries(this) + val cleanRight = EliminateSubQueries(plan) + + cleanLeft.getClass == cleanRight.getClass && + cleanLeft.children.size == cleanRight.children.size && { + logDebug( + s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]") + cleanRight.cleanArgs == cleanLeft.cleanArgs } && - (plan.children, children).zipped.forall(_ sameResult _) + (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) } /** Args that have cleaned such that differences in expression id should not affect equality */ @@ -114,57 +110,113 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ - def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] = - resolve(name, children.flatMap(_.output), resolver) + def resolveChildren( + name: String, + resolver: Resolver, + throwErrors: Boolean = false): Option[NamedExpression] = + resolve(name, children.flatMap(_.output), resolver, throwErrors) /** * Optionally resolves the given string to a [[NamedExpression]] based on the output of this * LogicalPlan. The attribute is expressed as string in the following form: * `[scope].AttributeName.[nested].[fields]...`. */ - def resolve(name: String, resolver: Resolver): Option[NamedExpression] = - resolve(name, output, resolver) + def resolve( + name: String, + resolver: Resolver, + throwErrors: Boolean = false): Option[NamedExpression] = + resolve(name, output, resolver, throwErrors) + + /** + * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. + * + * This assumes `name` has multiple parts, where the 1st part is a qualifier + * (i.e. table name, alias, or subquery alias). + * See the comment above `candidates` variable in resolve() for semantics the returned data. + */ + private def resolveAsTableColumn( + nameParts: Array[String], + resolver: Resolver, + attribute: Attribute): Option[(Attribute, List[String])] = { + assert(nameParts.length > 1) + if (attribute.qualifiers.exists(resolver(_, nameParts.head))) { + // At least one qualifier matches. See if remaining parts match. + val remainingParts = nameParts.tail + resolveAsColumn(remainingParts, resolver, attribute) + } else { + None + } + } + + /** + * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. + * + * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier. + * See the comment above `candidates` variable in resolve() for semantics the returned data. + */ + private def resolveAsColumn( + nameParts: Array[String], + resolver: Resolver, + attribute: Attribute): Option[(Attribute, List[String])] = { + if (resolver(attribute.name, nameParts.head)) { + Option((attribute.withName(nameParts.head), nameParts.tail.toList)) + } else { + None + } + } /** Performs attribute resolution given a name and a sequence of possible attributes. */ protected def resolve( name: String, input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { + resolver: Resolver, + throwErrors: Boolean): Option[NamedExpression] = { val parts = name.split("\\.") - // Collect all attributes that are output by this nodes children where either the first part - // matches the name or where the first part matches the scope and the second part matches the - // name. Return these matches along with any remaining parts, which represent dotted access to - // struct fields. - val options = input.flatMap { option => - // If the first part of the desired name matches a qualifier for this possible match, drop it. - val remainingParts = - if (option.qualifiers.find(resolver(_, parts.head)).nonEmpty && parts.size > 1) { - parts.drop(1) - } else { - parts + // A sequence of possible candidate matches. + // Each candidate is a tuple. The first element is a resolved attribute, followed by a list + // of parts that are to be resolved. + // For example, consider an example where "a" is the table name, "b" is the column name, + // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", + // and the second element will be List("c"). + var candidates: Seq[(Attribute, List[String])] = { + // If the name has 2 or more parts, try to resolve it as `table.column` first. + if (parts.length > 1) { + input.flatMap { option => + resolveAsTableColumn(parts, resolver, option) } - - if (resolver(option.name, remainingParts.head)) { - // Preserve the case of the user's attribute reference. - (option.withName(remainingParts.head), remainingParts.tail.toList) :: Nil } else { - Nil + Seq.empty + } + } + + // If none of attributes match `table.column` pattern, we try to resolve it as a column. + if (candidates.isEmpty) { + candidates = input.flatMap { candidate => + resolveAsColumn(parts, resolver, candidate) } } - options.distinct match { + candidates.distinct match { // One match, no nested fields, use it. case Seq((a, Nil)) => Some(a) // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => - val aliased = - Alias( - resolveNesting(nestedFields, a, resolver), - nestedFields.last)() // Preserve the case of the user's field access. - Some(aliased) + try { + + // The foldLeft adds UnresolvedGetField for every remaining parts of the name, + // and aliased it with the last part of the name. + // For example, consider name "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias + // the final expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression)(resolveGetField(_, _, resolver)) + val aliasName = nestedFields.last + Some(Alias(fieldExprs, aliasName)()) + } catch { + case a: AnalysisException if !throwErrors => None + } // No matches. case Seq() => @@ -173,33 +225,44 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // More than one match. case ambiguousReferences => - throw new TreeNodeException( - this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") + val referenceNames = ambiguousReferences.map(_._1).mkString(", ") + throw new AnalysisException( + s"Reference '$name' is ambiguous, could be: $referenceNames.") } } /** - * Given a list of successive nested field accesses, and a based expression, attempt to resolve - * the actual field lookups on this expression. + * Returns the resolved `GetField`, and report error if no desired field or over one + * desired fields are found. + * + * TODO: this code is duplicated from Analyzer and should be refactored to avoid this. */ - private def resolveNesting( - nestedFields: List[String], - expression: Expression, + protected def resolveGetField( + expr: Expression, + fieldName: String, resolver: Resolver): Expression = { - - (nestedFields, expression.dataType) match { - case (Nil, _) => expression - case (requestedField :: rest, StructType(fields)) => - val actualField = fields.filter(f => resolver(f.name, requestedField)) - if (actualField.length == 0) { - sys.error( - s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}") - } else if (actualField.length == 1) { - resolveNesting(rest, GetField(expression, actualField(0).name), resolver) - } else { - sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}") - } - case (_, dt) => sys.error(s"Can't access nested field in type $dt") + def findField(fields: Array[StructField]): Int = { + val checkField = (f: StructField) => resolver(f.name, fieldName) + val ordinal = fields.indexWhere(checkField) + if (ordinal == -1) { + throw new AnalysisException( + s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") + } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { + throw new AnalysisException( + s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") + } else { + ordinal + } + } + expr.dataType match { + case StructType(fields) => + val ordinal = findField(fields) + StructGetField(expr, fields(ordinal), ordinal) + case ArrayType(StructType(fields), containsNull) => + val ordinal = findField(fields) + ArrayGetField(expr, fields(ordinal), ordinal, containsNull) + case otherType => + throw new AnalysisException(s"GetField is not valid on fields of type $otherType") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index cfe2c7a39a17c..ccf5291219add 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Attribute, Expression} /** * Transforms the input by forking and running the specified script. @@ -32,7 +32,9 @@ case class ScriptTransformation( script: String, output: Seq[Attribute], child: LogicalPlan, - ioschema: ScriptInputOutputSchema) extends UnaryNode + ioschema: ScriptInputOutputSchema) extends UnaryNode { + override def references: AttributeSet = AttributeSet(input.flatMap(_.references)) +} /** * A placeholder for implementation specific input and output properties when passing data diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala new file mode 100644 index 0000000000000..9ac4c3a2a56c8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +/** + * Estimates of various statistics. The default estimation logic simply lazily multiplies the + * corresponding statistic produced by the children. To override this behavior, override + * `statistics` and assign it an overridden version of `Statistics`. + * + * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the + * performance of the implementations. The reason is that estimations might get triggered in + * performance-critical processes, such as query plan planning. + * + * Note that we are using a BigInt here since it is easy to overflow a 64-bit integer in + * cardinality estimation (e.g. cartesian joins). + * + * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it + * defaults to the product of children's `sizeInBytes`. + */ +private[sql] case class Statistics(sizeInBytes: BigInt) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 9628e93274a11..51da04b1f67ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -22,7 +22,17 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { - def output = projectList.map(_.toAttribute) + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + override lazy val resolved: Boolean = { + val containsAggregatesOrGenerators = projectList.exists ( _.collect { + case agg: AggregateExpression => agg + case generator: Generator => generator + }.nonEmpty + ) + + !expressions.exists(!_.resolved) && childrenResolved && !containsAggregatesOrGenerators + } } /** @@ -56,21 +66,21 @@ case class Generate( } } - override def output = + override def output: Seq[Attribute] = if (join) child.output ++ generatorOutput else generatorOutput } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { // TODO: These aren't really the same attributes as nullability etc might change. - override def output = left.output + override def output: Seq[Attribute] = left.output - override lazy val resolved = + override lazy val resolved: Boolean = childrenResolved && - !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } + left.output.zip(right.output).forall { case (l,r) => l.dataType == r.dataType } } case class Join( @@ -79,7 +89,7 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - override def output = { + override def output: Seq[Attribute] = { joinType match { case LeftSemi => left.output @@ -93,10 +103,17 @@ case class Join( left.output ++ right.output } } + + private def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + + // Joins are only resolved if they don't introduce ambiguious expression ids. + override lazy val resolved: Boolean = { + childrenResolved && !expressions.exists(!_.resolved) && selfJoinResolved + } } case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - def output = left.output + override def output: Seq[Attribute] = left.output } case class InsertIntoTable( @@ -106,11 +123,12 @@ case class InsertIntoTable( overwrite: Boolean) extends LogicalPlan { - override def children = child :: Nil - override def output = child.output + override def children: Seq[LogicalPlan] = child :: Nil + override def output: Seq[Attribute] = child.output - override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { - case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType + override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { + case (childAttr, tableAttr) => + DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) } } @@ -120,14 +138,14 @@ case class CreateTableAsSelect[T]( child: LogicalPlan, allowExisting: Boolean, desc: Option[T] = None) extends UnaryNode { - override def output = Seq.empty[Attribute] - override lazy val resolved = databaseName != None && childrenResolved + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = databaseName != None && childrenResolved } case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } /** @@ -140,7 +158,7 @@ case class Sort( order: Seq[SortOrder], global: Boolean, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Aggregate( @@ -149,7 +167,7 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { - override def output = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } /** @@ -171,7 +189,7 @@ trait GroupingAnalytics extends UnaryNode { def groupByExprs: Seq[Expression] def aggregations: Seq[NamedExpression] - override def output = aggregations.map(_.toAttribute) + override def output: Seq[Attribute] = aggregations.map(_.toAttribute) } /** @@ -236,7 +254,7 @@ case class Rollup( gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output override lazy val statistics: Statistics = { val limit = limitExpr.eval(null).asInstanceOf[Int] @@ -246,23 +264,35 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { - override def output = child.output.map(_.withQualifiers(alias :: Nil)) + override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Distinct(child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } -case object NoRelation extends LeafNode { - override def output = Nil +/** + * A relation with one row. This is used in "SELECT ..." without a from clause. + */ +case object OneRowRelation extends LeafNode { + override def output: Seq[Attribute] = Nil + + /** + * Computes [[Statistics]] for this plan. The default implementation assumes the output + * cardinality is the product of of all child plan's cardinality, i.e. applies in the case + * of cartesian joins. + * + * [[LeafNode]]s must override this. + */ + override def statistics: Statistics = Statistics(sizeInBytes = 1) } case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output = left.output + override def output: Seq[Attribute] = left.output } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 72b0c5c8e7a26..e737418d9c3bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} /** * Performs a physical redistribution of the data. Used when the consumer of the query @@ -26,14 +26,11 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} abstract class RedistributeData extends UnaryNode { self: Product => - def output = child.output + override def output: Seq[Attribute] = child.output } case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) - extends RedistributeData { -} + extends RedistributeData case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) - extends RedistributeData { -} - + extends RedistributeData diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 3c3d7a3119064..288c11f69fe22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder} -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{DataType, IntegerType} /** * Specifies how tuples that share common expressions will be distributed when a query is executed @@ -72,7 +72,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "a single partition.") // TODO: This is not really valid... - def clustering = ordering.map(_.child).toSet + def clustering: Set[Expression] = ordering.map(_.child).toSet } sealed trait Partitioning { @@ -113,7 +113,7 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case SinglePartition => true case _ => false } @@ -124,7 +124,7 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case SinglePartition => true case _ => false } @@ -139,9 +139,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning { - override def children = expressions - override def nullable = false - override def dataType = IntegerType + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: DataType = IntegerType private[this] lazy val clusteringSet = expressions.toSet @@ -152,7 +152,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true case h: HashPartitioning if h == this => true case _ => false @@ -178,9 +178,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) extends Expression with Partitioning { - override def children = ordering - override def nullable = false - override def dataType = IntegerType + override def children: Seq[SortOrder] = ordering + override def nullable: Boolean = false + override def dataType: DataType = IntegerType private[this] lazy val clusteringSet = ordering.map(_.child).toSet @@ -194,7 +194,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true case r: RangePartitioning if r == this => true case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 2013ae4f7bd13..53257bdf61cc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -18,13 +18,47 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.types.DataType /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) +case class Origin( + line: Option[Int] = None, + startPosition: Option[Int] = None) + +/** + * Provides a location for TreeNodes to ask about the context of their origin. For example, which + * line of code is currently being parsed. + */ +object CurrentOrigin { + private val value = new ThreadLocal[Origin]() { + override def initialValue: Origin = Origin() + } + + def get: Origin = value.get() + def set(o: Origin): Unit = value.set(o) + + def reset(): Unit = value.set(Origin()) + + def setPosition(line: Int, start: Int): Unit = { + value.set( + value.get.copy(line = Some(line), startPosition = Some(start))) + } + + def withOrigin[A](o: Origin)(f: => A): A = { + set(o) + val ret = try f finally { reset() } + reset() + ret + } +} + abstract class TreeNode[BaseType <: TreeNode[BaseType]] { self: BaseType with Product => + val origin: Origin = CurrentOrigin.get + /** Returns a Seq of the children of this node */ def children: Seq[BaseType] @@ -46,6 +80,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { children.foreach(_.foreach(f)) } + /** + * Runs the given function recursively on [[children]] then on this node. + * @param f the function to be applied to each node in the tree. + */ + def foreachUp(f: BaseType => Unit): Unit = { + children.foreach(_.foreach(f)) + f(this) + } + /** * Returns a Seq containing the result of applying the given function to each * node in this tree in a preorder traversal. @@ -141,7 +184,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param rule the function used to transform this nodes children */ def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = { - val afterRule = rule.applyOrElse(this, identity[BaseType]) + val afterRule = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[BaseType]) + } + // Check if unchanged and then possibly return old copy to avoid gc churn. if (this fastEquals afterRule) { transformChildrenDown(rule) @@ -175,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) @@ -201,9 +248,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { val afterRuleOnChildren = transformChildrenUp(rule); if (this fastEquals afterRuleOnChildren) { - rule.applyOrElse(this, identity[BaseType]) + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[BaseType]) + } } else { - rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) + } } } @@ -227,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) @@ -258,29 +310,42 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { + val defaultCtor = + getClass.getConstructors + .find(_.getParameterTypes.size != 0) + .headOption + .getOrElse(sys.error(s"No valid constructor for $nodeName")) + try { - // Skip no-arg constructors that are just there for kryo. - val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head - if (otherCopyArgs.isEmpty) { - defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] - } else { - defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type] + CurrentOrigin.withOrigin(origin) { + // Skip no-arg constructors that are just there for kryo. + if (otherCopyArgs.isEmpty) { + defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] + } else { + defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type] + } } } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? " - + s"Exception message: ${e.getMessage}.") + this, + s""" + |Failed to copy node. + |Is otherCopyArgs specified correctly for $nodeName. + |Exception message: ${e.getMessage} + |ctor: $defaultCtor? + |args: ${newArgs.mkString(", ")} + """.stripMargin) } } /** Returns the name of this type of TreeNode. Defaults to the class name. */ - def nodeName = getClass.getSimpleName + def nodeName: String = getClass.getSimpleName /** * The arguments that should be included in the arg string. Defaults to the `productIterator`. */ - protected def stringArgs = productIterator + protected def stringArgs: Iterator[Any] = productIterator /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { @@ -292,18 +357,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { }.mkString(", ") /** String representation of this node without any children */ - def simpleString = s"$nodeName $argString" + def simpleString: String = s"$nodeName $argString".trim override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString = generateTreeString(0, new StringBuilder).toString + def treeString: String = generateTreeString(0, new StringBuilder).toString /** * Returns a string representation of the nodes in this tree, where each operator is numbered. * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees. */ - def numberedTreeString = + def numberedTreeString: String = treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n") /** @@ -355,14 +420,14 @@ trait BinaryNode[BaseType <: TreeNode[BaseType]] { def left: BaseType def right: BaseType - def children = Seq(left, right) + def children: Seq[BaseType] = Seq(left, right) } /** * A [[TreeNode]] with no children. */ trait LeafNode[BaseType <: TreeNode[BaseType]] { - def children = Nil + def children: Seq[BaseType] = Nil } /** @@ -370,6 +435,5 @@ trait LeafNode[BaseType <: TreeNode[BaseType]] { */ trait UnaryNode[BaseType <: TreeNode[BaseType]] { def child: BaseType - def children = child :: Nil + def children: Seq[BaseType] = child :: Nil } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index 79a8e06d4b4d4..ea6aa1850db4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -41,11 +41,11 @@ package object trees extends Logging { * A [[TreeNode]] companion for reference equality for Hash based Collection. */ class TreeNodeRef(val obj: TreeNode[_]) { - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case that: TreeNodeRef => that.obj.eq(obj) case _ => false } - override def hashCode = if (obj == null) 0 else obj.hashCode + override def hashCode: Int = if (obj == null) 0 else obj.hashCode } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index d8da45ae70c4b..dead02d873d75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -34,7 +34,7 @@ package object util { tempFile } - def fileToString(file: File, encoding: String = "UTF-8") = { + def fileToString(file: File, encoding: String = "UTF-8"): String = { val inStream = new FileInputStream(file) val outStream = new ByteArrayOutputStream try { @@ -56,7 +56,7 @@ package object util { def resourceToString( resource:String, encoding: String = "UTF-8", - classLoader: ClassLoader = SparkUtils.getSparkClassLoader) = { + classLoader: ClassLoader = SparkUtils.getSparkClassLoader): String = { val inStream = classLoader.getResourceAsStream(resource) val outStream = new ByteArrayOutputStream try { @@ -104,7 +104,7 @@ package object util { new String(out.toByteArray) } - def stringOrNull(a: AnyRef) = if (a == null) null else a.toString + def stringOrNull(a: AnyRef): String = if (a == null) null else a.toString def benchmark[A](f: => A): A = { val startTime = System.nanoTime() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala index 21f478c80c94b..a9d63e784963d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala @@ -19,10 +19,26 @@ package org.apache.spark.sql.types import java.text.SimpleDateFormat +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -protected[sql] object DataTypeConversions { +private[sql] object DataTypeConversions { + + def productToRow(product: Product, schema: StructType): Row = { + val mutableRow = new GenericMutableRow(product.productArity) + val schemaFields = schema.fields.toArray + + var i = 0 + while (i < mutableRow.length) { + mutableRow(i) = + ScalaReflection.convertToCatalyst(product.productElement(i), schemaFields(i).dataType) + i += 1 + } + + mutableRow + } def stringToTime(s: String): java.util.Date = { if (!s.contains('T')) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala new file mode 100644 index 0000000000000..34270d0ca7cd7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.language.implicitConversions +import scala.util.matching.Regex +import scala.util.parsing.combinator.syntactical.StandardTokenParsers + +import org.apache.spark.sql.catalyst.SqlLexical + +/** + * This is a data type parser that can be used to parse string representations of data types + * provided in SQL queries. This parser is mixed in with DDLParser and SqlParser. + */ +private[sql] trait DataTypeParser extends StandardTokenParsers { + + // This is used to create a parser from a regex. We are using regexes for data type strings + // since these strings can be also used as column names or field names. + import lexical.Identifier + implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex ${regex}", + { case Identifier(str) if regex.unapplySeq(str).isDefined => str } + ) + + protected lazy val primitiveType: Parser[DataType] = + "(?i)string".r ^^^ StringType | + "(?i)float".r ^^^ FloatType | + "(?i)int".r ^^^ IntegerType | + "(?i)tinyint".r ^^^ ByteType | + "(?i)smallint".r ^^^ ShortType | + "(?i)double".r ^^^ DoubleType | + "(?i)bigint".r ^^^ LongType | + "(?i)binary".r ^^^ BinaryType | + "(?i)boolean".r ^^^ BooleanType | + fixedDecimalType | + "(?i)decimal".r ^^^ DecimalType.Unlimited | + "(?i)date".r ^^^ DateType | + "(?i)timestamp".r ^^^ TimestampType | + varchar + + protected lazy val fixedDecimalType: Parser[DataType] = + ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { + case precision ~ scale => + DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val varchar: Parser[DataType] = + "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType + + protected lazy val arrayType: Parser[DataType] = + "(?i)array".r ~> "<" ~> dataType <~ ">" ^^ { + case tpe => ArrayType(tpe) + } + + protected lazy val mapType: Parser[DataType] = + "(?i)map".r ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { + case t1 ~ _ ~ t2 => MapType(t1, t2) + } + + protected lazy val structField: Parser[StructField] = + ident ~ ":" ~ dataType ^^ { + case name ~ _ ~ tpe => StructField(name, tpe, nullable = true) + } + + protected lazy val structType: Parser[DataType] = + ("(?i)struct".r ~> "<" ~> repsep(structField, ",") <~ ">" ^^ { + case fields => new StructType(fields.toArray) + }) | + ("(?i)struct".r ~ "<>" ^^^ StructType(Nil)) + + protected lazy val dataType: Parser[DataType] = + arrayType | + mapType | + structType | + primitiveType + + def toDataType(dataTypeString: String): DataType = synchronized { + phrase(dataType)(new lexical.Scanner(dataTypeString)) match { + case Success(result, _) => result + case failure: NoSuccess => throw new DataTypeException(failMessage(dataTypeString)) + } + } + + private def failMessage(dataTypeString: String): String = { + s"Unsupported dataType: $dataTypeString. If you have a struct and a field name of it has " + + "any special characters, please use backticks (`) to quote that field name, e.g. `x+y`. " + + "Please note that backtick itself is not supported in a field name." + } +} + +private[sql] object DataTypeParser { + lazy val dataTypeParser = new DataTypeParser { + override val lexical = new SqlLexical + } + + def apply(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) +} + +/** The exception thrown from the [[DataTypeParser]]. */ +private[sql] class DataTypeException(message: String) extends Exception(message) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala new file mode 100644 index 0000000000000..8a1a3b81b3d2c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import java.sql.Date +import java.util.{Calendar, TimeZone} + +import org.apache.spark.sql.catalyst.expressions.Cast + +/** + * helper function to convert between Int value of days since 1970-01-01 and java.sql.Date + */ +object DateUtils { + private val MILLIS_PER_DAY = 86400000 + + // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. + private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] { + override protected def initialValue: TimeZone = { + Calendar.getInstance.getTimeZone + } + } + + private def javaDateToDays(d: Date): Int = { + millisToDays(d.getTime) + } + + def millisToDays(millisLocal: Long): Int = { + ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt + } + + private def toMillisSinceEpoch(days: Int): Long = { + val millisUtc = days.toLong * MILLIS_PER_DAY + millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc) + } + + def fromJavaDate(date: java.sql.Date): Int = { + javaDateToDays(date) + } + + def toJavaDate(daysSinceEpoch: Int): java.sql.Date = { + new java.sql.Date(toMillisSinceEpoch(daysSinceEpoch)) + } + + def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 21cc6cea4bf54..994c5202c15dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -246,7 +246,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case d: Decimal => compare(d) == 0 case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala old mode 100755 new mode 100644 index e50e9761431f5..6ee24ee0c1913 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -41,6 +41,9 @@ import org.apache.spark.annotation.DeveloperApi sealed class Metadata private[types] (private[types] val map: Map[String, Any]) extends Serializable { + /** No-arg constructor for kryo. */ + protected def this() = this(null) + /** Tests whether this Metadata contains a binding for a key. */ def contains(key: String): Boolean = map.contains(key) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index defdcb2b706f5..952cf5c75688d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.types -import java.sql.{Date, Timestamp} +import java.sql.Timestamp +import scala.collection.mutable.ArrayBuffer import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} @@ -29,6 +30,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} @@ -159,7 +161,6 @@ object DataType { case failure: NoSuccess => throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") } - } protected[types] def buildFormattedString( @@ -180,7 +181,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. */ - private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) @@ -197,12 +198,48 @@ object DataType { case (left, right) => left == right } } + + /** + * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType. + * + * Compatible nullability is defined as follows: + * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` + * if and only if `to.containsNull` is true, or both of `from.containsNull` and + * `to.containsNull` are false. + * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` + * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and + * `to.valueContainsNull` are false. + * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` + * if and only if for all every pair of fields, `to.nullable` is true, or both + * of `fromField.nullable` and `toField.nullable` are false. + */ + private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => + (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + (tn || !fn) && + equalsIgnoreCompatibleNullability(fromKey, toKey) && + equalsIgnoreCompatibleNullability(fromValue, toValue) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.size == toFields.size && + fromFields.zip(toFields).forall { + case (fromField, toField) => + fromField.name == toField.name && + (toField.nullable || !fromField.nullable) && + equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } } /** * :: DeveloperApi :: - * * The base type of all Spark SQL data types. * * @group dataType @@ -227,21 +264,39 @@ abstract class DataType { def json: String = compact(render(jsonValue)) def prettyJson: String = pretty(render(jsonValue)) -} + def simpleString: String = typeName + + /** Check if `this` and `other` are the same data type when ignoring nullability + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + private[spark] def sameType(other: DataType): Boolean = + DataType.equalsIgnoreNullability(this, other) + + /** Returns the same data type but set all nullability fields are true + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + private[spark] def asNullable: DataType +} /** * :: DeveloperApi :: - * * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. * * @group dataType */ @DeveloperApi -case object NullType extends DataType { +class NullType private() extends DataType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "NullType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. override def defaultSize: Int = 1 + + private[spark] override def asNullable: NullType = this } +case object NullType extends NullType + protected[sql] object NativeType { val all = Seq( @@ -252,7 +307,7 @@ protected[sql] object NativeType { protected[sql] trait PrimitiveType extends DataType { - override def isPrimitive = true + override def isPrimitive: Boolean = true } @@ -285,13 +340,15 @@ protected[sql] abstract class NativeType extends DataType { /** * :: DeveloperApi :: - * * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. * * @group dataType */ @DeveloperApi -case object StringType extends NativeType with PrimitiveType { +class StringType private() extends NativeType with PrimitiveType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "StringType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] @@ -300,19 +357,25 @@ case object StringType extends NativeType with PrimitiveType { * The default size of a value of the StringType is 4096 bytes. */ override def defaultSize: Int = 4096 + + private[spark] override def asNullable: StringType = this } +case object StringType extends StringType + /** * :: DeveloperApi :: - * * The data type representing `Array[Byte]` values. * Please use the singleton [[DataTypes.BinaryType]]. * * @group dataType */ @DeveloperApi -case object BinaryType extends NativeType with PrimitiveType { +class BinaryType private() extends NativeType with PrimitiveType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Array[Byte] @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = new Ordering[JvmType] { @@ -329,18 +392,24 @@ case object BinaryType extends NativeType with PrimitiveType { * The default size of a value of the BinaryType is 4096 bytes. */ override def defaultSize: Int = 4096 + + private[spark] override def asNullable: BinaryType = this } +case object BinaryType extends BinaryType + /** * :: DeveloperApi :: - * * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. * *@group dataType */ @DeveloperApi -case object BooleanType extends NativeType with PrimitiveType { +class BooleanType private() extends NativeType with PrimitiveType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Boolean @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] @@ -349,60 +418,80 @@ case object BooleanType extends NativeType with PrimitiveType { * The default size of a value of the BooleanType is 1 byte. */ override def defaultSize: Int = 1 + + private[spark] override def asNullable: BooleanType = this } +case object BooleanType extends BooleanType + /** * :: DeveloperApi :: - * * The data type representing `java.sql.Timestamp` values. * Please use the singleton [[DataTypes.TimestampType]]. * * @group dataType */ @DeveloperApi -case object TimestampType extends NativeType { +class TimestampType private() extends NativeType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Timestamp @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = new Ordering[JvmType] { - def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) } /** - * The default size of a value of the TimestampType is 8 bytes. + * The default size of a value of the TimestampType is 12 bytes. */ - override def defaultSize: Int = 8 + override def defaultSize: Int = 12 + + private[spark] override def asNullable: TimestampType = this } +case object TimestampType extends TimestampType + /** * :: DeveloperApi :: - * * The data type representing `java.sql.Date` values. * Please use the singleton [[DataTypes.DateType]]. * * @group dataType */ @DeveloperApi -case object DateType extends NativeType { - private[sql] type JvmType = Date +class DateType private() extends NativeType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "DateType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type JvmType = Int @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val ordering = new Ordering[JvmType] { - def compare(x: Date, y: Date) = x.compareTo(y) - } + private[sql] val ordering = implicitly[Ordering[JvmType]] /** - * The default size of a value of the DateType is 8 bytes. + * The default size of a value of the DateType is 4 bytes. */ - override def defaultSize: Int = 8 + override def defaultSize: Int = 4 + + private[spark] override def asNullable: DateType = this } +case object DateType extends DateType + -protected[sql] abstract class NumericType extends NativeType with PrimitiveType { +/** + * :: DeveloperApi :: + * Numeric data types. + * + * @group dataType + */ +abstract class NumericType extends NativeType with PrimitiveType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets @@ -433,13 +522,15 @@ protected[sql] sealed abstract class IntegralType extends NumericType { /** * :: DeveloperApi :: - * * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. * * @group dataType */ @DeveloperApi -case object LongType extends IntegralType { +class LongType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "LongType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Long @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Long]] @@ -450,18 +541,26 @@ case object LongType extends IntegralType { * The default size of a value of the LongType is 8 bytes. */ override def defaultSize: Int = 8 + + override def simpleString: String = "bigint" + + private[spark] override def asNullable: LongType = this } +case object LongType extends LongType + /** * :: DeveloperApi :: - * * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. * * @group dataType */ @DeveloperApi -case object IntegerType extends IntegralType { +class IntegerType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Int @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Int]] @@ -472,18 +571,26 @@ case object IntegerType extends IntegralType { * The default size of a value of the IntegerType is 4 bytes. */ override def defaultSize: Int = 4 + + override def simpleString: String = "int" + + private[spark] override def asNullable: IntegerType = this } +case object IntegerType extends IntegerType + /** * :: DeveloperApi :: - * * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. * * @group dataType */ @DeveloperApi -case object ShortType extends IntegralType { +class ShortType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Short @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Short]] @@ -494,18 +601,26 @@ case object ShortType extends IntegralType { * The default size of a value of the ShortType is 2 bytes. */ override def defaultSize: Int = 2 + + override def simpleString: String = "smallint" + + private[spark] override def asNullable: ShortType = this } +case object ShortType extends ShortType + /** * :: DeveloperApi :: - * * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. * * @group dataType */ @DeveloperApi -case object ByteType extends IntegralType { +class ByteType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Byte @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Byte]] @@ -516,8 +631,14 @@ case object ByteType extends IntegralType { * The default size of a value of the ByteType is 1 byte. */ override def defaultSize: Int = 1 + + override def simpleString: String = "tinyint" + + private[spark] override def asNullable: ByteType = this } +case object ByteType extends ByteType + /** Matcher for any expressions that evaluate to [[FractionalType]]s */ protected[sql] object FractionalType { @@ -540,7 +661,6 @@ case class PrecisionInfo(precision: Int, scale: Int) /** * :: DeveloperApi :: - * * The data type representing `java.math.BigDecimal` values. * A Decimal that might have fixed precision and scale, or unlimited values for these. * @@ -550,6 +670,10 @@ case class PrecisionInfo(precision: Int, scale: Int) */ @DeveloperApi case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null) + private[sql] type JvmType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = Decimal.DecimalIsFractional @@ -575,6 +699,13 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT * The default size of a value of the DecimalType is 4096 bytes. */ override def defaultSize: Int = 4096 + + override def simpleString: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal(10,0)" + } + + private[spark] override def asNullable: DecimalType = this } @@ -612,13 +743,15 @@ object DecimalType { /** * :: DeveloperApi :: - * * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. * * @group dataType */ @DeveloperApi -case object DoubleType extends FractionalType { +class DoubleType private() extends FractionalType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Double @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Double]] @@ -630,18 +763,24 @@ case object DoubleType extends FractionalType { * The default size of a value of the DoubleType is 8 bytes. */ override def defaultSize: Int = 8 + + private[spark] override def asNullable: DoubleType = this } +case object DoubleType extends DoubleType + /** * :: DeveloperApi :: - * * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. * * @group dataType */ @DeveloperApi -case object FloatType extends FractionalType { +class FloatType private() extends FractionalType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Float @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Float]] @@ -653,8 +792,12 @@ case object FloatType extends FractionalType { * The default size of a value of the FloatType is 4 bytes. */ override def defaultSize: Int = 4 + + private[spark] override def asNullable: FloatType = this } +case object FloatType extends FloatType + object ArrayType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ @@ -664,7 +807,6 @@ object ArrayType { /** * :: DeveloperApi :: - * * The data type for collections of multiple values. * Internally these are represented as columns that contain a ``scala.collection.Seq``. * @@ -681,6 +823,10 @@ object ArrayType { */ @DeveloperApi case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null, false) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") @@ -697,12 +843,16 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT * (We assume that there are 100 elements). */ override def defaultSize: Int = 100 * elementType.defaultSize + + override def simpleString: String = s"array<${elementType.simpleString}>" + + private[spark] override def asNullable: ArrayType = + ArrayType(elementType.asNullable, containsNull = true) } /** * A field inside a StructType. - * * @param name The name of this field. * @param dataType The data type of this field. * @param nullable Indicates if values of this field can be `null` values. @@ -715,6 +865,9 @@ case class StructField( nullable: Boolean = true, metadata: Metadata = Metadata.empty) { + /** No-arg constructor for kryo. */ + protected def this() = this(null, null) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) @@ -741,12 +894,62 @@ object StructType { def apply(fields: java.util.List[StructField]): StructType = { StructType(fields.toArray.asInstanceOf[Array[StructField]]) } + + private[sql] def merge(left: DataType, right: DataType): DataType = + (left, right) match { + case (ArrayType(leftElementType, leftContainsNull), + ArrayType(rightElementType, rightContainsNull)) => + ArrayType( + merge(leftElementType, rightElementType), + leftContainsNull || rightContainsNull) + + case (MapType(leftKeyType, leftValueType, leftContainsNull), + MapType(rightKeyType, rightValueType, rightContainsNull)) => + MapType( + merge(leftKeyType, rightKeyType), + merge(leftValueType, rightValueType), + leftContainsNull || rightContainsNull) + + case (StructType(leftFields), StructType(rightFields)) => + val newFields = ArrayBuffer.empty[StructField] + + leftFields.foreach { + case leftField @ StructField(leftName, leftType, leftNullable, _) => + rightFields + .find(_.name == leftName) + .map { case rightField @ StructField(_, rightType, rightNullable, _) => + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } + .orElse(Some(leftField)) + .foreach(newFields += _) + } + + rightFields + .filterNot(f => leftFields.map(_.name).contains(f.name)) + .foreach(newFields += _) + + StructType(newFields) + + case (DecimalType.Fixed(leftPrecision, leftScale), + DecimalType.Fixed(rightPrecision, rightScale)) => + DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale)) + + case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) + if leftUdt.userClass == rightUdt.userClass => leftUdt + + case (leftType, rightType) if leftType == rightType => + leftType + + case _ => + throw new SparkException(s"Failed to merge incompatible data types $left and $right") + } } /** * :: DeveloperApi :: - * * A [[StructType]] object can be constructed by * {{{ * StructType(fields: Seq[StructField]) @@ -811,6 +1014,9 @@ object StructType { @DeveloperApi case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { + /** No-arg constructor for kryo. */ + protected def this() = this(null) + /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) @@ -872,6 +1078,34 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * The default size of a value of the StructType is the total default sizes of all field types. */ override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum + + override def simpleString: String = { + val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}") + s"struct<${fieldTypes.mkString(",")}>" + } + + /** + * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field + * B from `that`, + * + * 1. If A and B have the same name and data type, they are merged to a field C with the same name + * and data type. C is nullable if and only if either A or B is nullable. + * 2. If A doesn't exist in `that`, it's included in the result schema. + * 3. If B doesn't exist in `this`, it's also included in the result schema. + * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be + * thrown. + */ + private[sql] def merge(that: StructType): StructType = + StructType.merge(this, that).asInstanceOf[StructType] + + private[spark] override def asNullable: StructType = { + val newFields = fields.map { + case StructField(name, dataType, nullable, metadata) => + StructField(name, dataType.asNullable, nullable = true, metadata) + } + + StructType(newFields) + } } @@ -887,7 +1121,6 @@ object MapType { /** * :: DeveloperApi :: - * * The data type for Maps. Keys in a map are not allowed to have `null` values. * * Please use [[DataTypes.createMapType()]] to create a specific instance. @@ -902,6 +1135,10 @@ case class MapType( keyType: DataType, valueType: DataType, valueContainsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + def this() = this(null, null, false) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- key: ${keyType.typeName}\n") builder.append(s"$prefix-- value: ${valueType.typeName} " + @@ -922,6 +1159,11 @@ case class MapType( * (We assume that there are 100 elements). */ override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) + + override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + + private[spark] override def asNullable: MapType = + MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) } @@ -975,4 +1217,10 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { * The default size of a value of the UserDefinedType is 4096 bytes. */ override def defaultSize: Int = 4096 + + /** + * For UDT, asNullable will not change the nullability of its internal sqlType and just returns + * itself. + */ + private[spark] override def asNullable: UserDefinedType[UserType] = this } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index d0f547d187ecb..eee00e3f7ea76 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -61,6 +61,7 @@ case class OptionalData( case class ComplexData( arrayField: Seq[Int], arrayField1: Array[Int], + arrayField2: List[Int], arrayFieldContainsNull: Seq[java.lang.Integer], mapField: Map[Int, Long], mapFieldValueContainsNull: Map[Int, java.lang.Long], @@ -137,6 +138,10 @@ class ScalaReflectionSuite extends FunSuite { "arrayField1", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField( + "arrayField2", + ArrayType(IntegerType, containsNull = false), + nullable = true), StructField( "arrayFieldContainsNull", ArrayType(IntegerType, containsNull = true), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 60060bf02913b..ee7b14c7a157c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -30,10 +30,22 @@ import org.apache.spark.sql.catalyst.dsl.plans._ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) val caseInsensitiveCatalog = new SimpleCatalog(false) - val caseSensitiveAnalyze = - new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) - val caseInsensitiveAnalyze = - new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) + + val caseSensitiveAnalyzer = + new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } + val caseInsensitiveAnalyzer = + new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } + + + def caseSensitiveAnalyze(plan: LogicalPlan) = + caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan)) + + def caseInsensitiveAnalyze(plan: LogicalPlan) = + caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan)) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val testRelation2 = LocalRelation( @@ -43,6 +55,21 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + before { caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) @@ -55,35 +82,46 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) } - assert(caseInsensitiveAnalyze(plan).resolved) + assert(caseInsensitiveAnalyzer(plan).resolved) + } + + test("check project's resolved") { + assert(Project(testRelation.output, testRelation).resolved) + + assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved) + + val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)()) + assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved) + + assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved) } test("analyze project") { assert( - caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) === + caseSensitiveAnalyzer(Project(Seq(UnresolvedAttribute("a")), testRelation)) === Project(testRelation.output, testRelation)) assert( - caseSensitiveAnalyze( + caseSensitiveAnalyzer( Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) - val e = intercept[TreeNodeException[_]] { + val e = intercept[AnalysisException] { caseSensitiveAnalyze( Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) } - assert(e.getMessage().toLowerCase.contains("unresolved")) + assert(e.getMessage().toLowerCase.contains("cannot resolve")) assert( - caseInsensitiveAnalyze( + caseInsensitiveAnalyzer( Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) assert( - caseInsensitiveAnalyze( + caseInsensitiveAnalyzer( Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) @@ -96,36 +134,83 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(e.getMessage == "Table Not Found: tAbLe") assert( - caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) === - testRelation) + caseSensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) assert( - caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) === - testRelation) + caseInsensitiveAnalyzer(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) assert( - caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) === - testRelation) + caseInsensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) } - test("throw errors for unresolved attributes during analysis") { - val e = intercept[TreeNodeException[_]] { - caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation)) + def errorTest( + name: String, + plan: LogicalPlan, + errorMessages: Seq[String], + caseSensitive: Boolean = true) = { + test(name) { + val error = intercept[AnalysisException] { + if(caseSensitive) { + caseSensitiveAnalyze(plan) + } else { + caseInsensitiveAnalyze(plan) + } + } + + errorMessages.foreach(m => assert(error.getMessage contains m)) } - assert(e.getMessage().toLowerCase.contains("unresolved attribute")) } - test("throw errors for unresolved plans during analysis") { - case class UnresolvedTestPlan() extends LeafNode { - override lazy val resolved = false - override def output = Nil - } - val e = intercept[TreeNodeException[_]] { - caseSensitiveAnalyze(UnresolvedTestPlan()) - } - assert(e.getMessage().toLowerCase.contains("unresolved plan")) + errorTest( + "unresolved attributes", + testRelation.select('abcd), + "cannot resolve" :: "abcd" :: Nil) + + errorTest( + "bad casts", + testRelation.select(Literal(1).cast(BinaryType).as('badCast)), + "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + + errorTest( + "non-boolean filters", + testRelation.where(Literal(1)), + "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + + errorTest( + "missing group by", + testRelation2.groupBy('a)('b), + "'b'" :: "group by" :: Nil + ) + + errorTest( + "ambiguous field", + nestedRelation.select($"top.duplicateField"), + "Ambiguous reference to fields" :: "duplicateField" :: Nil, + caseSensitive = false) + + errorTest( + "ambiguous field due to case insensitivity", + nestedRelation.select($"top.differentCase"), + "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil, + caseSensitive = false) + + errorTest( + "missing field", + nestedRelation2.select($"top.c"), + "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil, + caseSensitive = false) + + case class UnresolvedTestPlan() extends LeafNode { + override lazy val resolved = false + override def output = Nil } + errorTest( + "catch all unresolved plan", + UnresolvedTestPlan(), + "unresolved" :: Nil) + + test("divide should be casted into fractional types") { val testRelation2 = LocalRelation( AttributeReference("a", StringType)(), @@ -134,22 +219,37 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) - val expr0 = 'a / 2 - val expr1 = 'a / 'b - val expr2 = 'a / 'c - val expr3 = 'a / 'd - val expr4 = 'e / 'e - val plan = caseInsensitiveAnalyze(Project( - Alias(expr0, s"Analyzer($expr0)")() :: - Alias(expr1, s"Analyzer($expr1)")() :: - Alias(expr2, s"Analyzer($expr2)")() :: - Alias(expr3, s"Analyzer($expr3)")() :: - Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2)) + val plan = caseInsensitiveAnalyzer( + testRelation2.select( + 'a / Literal(2) as 'div1, + 'a / 'b as 'div2, + 'a / 'c as 'div3, + 'a / 'd as 'div4, + 'e / 'e as 'div5)) val pl = plan.asInstanceOf[Project].projectList + assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } + + test("SPARK-6452 regression test") { + // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + val plan = + Aggregate( + Nil, + Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil, + LocalRelation( + AttributeReference("a", StringType)(exprId = ExprId(2)))) + + assert(plan.resolved) + + val message = intercept[AnalysisException] { + caseSensitiveAnalyze(plan) + }.getMessage + + assert(message.contains("resolved attribute(s) a#1 missing from a#2")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 85798d0871fda..ecbb54218d457 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.types._ -class HiveTypeCoercionSuite extends FunSuite { +class HiveTypeCoercionSuite extends PlanTest { test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { @@ -106,7 +106,8 @@ class HiveTypeCoercionSuite extends FunSuite { val booleanCasts = new HiveTypeCoercion { }.BooleanCasts def ruleTest(initial: Expression, transformed: Expression) { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) == + comparePlans( + booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)), Project(Seq(Alias(transformed, "a")()), testRelation)) } // Remove superflous boolean -> boolean casts. @@ -119,7 +120,8 @@ class HiveTypeCoercionSuite extends FunSuite { val fac = new HiveTypeCoercion { }.FunctionArgumentConversion def ruleTest(initial: Expression, transformed: Expression) { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - assert(fac(Project(Seq(Alias(initial, "a")()), testRelation)) == + comparePlans( + fac(Project(Seq(Alias(initial, "a")()), testRelation)), Project(Seq(Alias(transformed, "a")()), testRelation)) } ruleTest( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala new file mode 100644 index 0000000000000..f2f3a84d19380 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.FunSuite + +import org.apache.spark.sql.types.IntegerType + +class AttributeSetSuite extends FunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + val aSet = AttributeSet(aLower :: Nil) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + val bSet = AttributeSet(bUpper :: Nil) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + test("sanity check") { + assert(aUpper != aLower) + assert(bUpper != bLower) + } + + test("checks by id not name") { + assert(aSet.contains(aUpper) === true) + assert(aSet.contains(aLower) === true) + assert(aSet.contains(fakeA) === false) + + assert(aSet.contains(bUpper) === false) + assert(aSet.contains(bLower) === false) + } + + test("++ preserves AttributeSet") { + assert((aSet ++ bSet).contains(aUpper) === true) + assert((aSet ++ bSet).contains(aLower) === true) + } + + test("extracts all references references") { + val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil) + assert(addSet.contains(aUpper)) + assert(addSet.contains(aLower)) + assert(addSet.contains(bUpper)) + assert(addSet.contains(bLower)) + } + + test("dedups attributes") { + assert(AttributeSet(aUpper :: aLower :: Nil).size === 1) + } + + test("subset") { + assert(aSet.subsetOf(aAndBSet) === true) + assert(aAndBSet.subsetOf(aSet) === false) + } + + test("equality") { + assert(aSet != aAndBSet) + assert(aAndBSet != aSet) + assert(aSet != bSet) + assert(bSet != aSet) + + assert(aSet == aSet) + assert(aSet == AttributeSet(aUpper :: Nil)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 37e64adeea853..dcfd8b28cb02a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.FunSuite import org.scalatest.Matchers._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.types._ @@ -303,6 +304,7 @@ class ExpressionEvaluationSuite extends FunSuite { val sd = "1970-01-01" val d = Date.valueOf(sd) + val zts = sd + " 00:00:00" val sts = sd + " 00:00:02" val nts = sts + ".1" val ts = Timestamp.valueOf(nts) @@ -319,14 +321,14 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) checkEvaluation(Cast(Literal(sd) cast DateType, StringType), sd) - checkEvaluation(Cast(Literal(d) cast StringType, DateType), d) + checkEvaluation(Cast(Literal(d) cast StringType, DateType), 0) checkEvaluation(Cast(Literal(nts) cast TimestampType, StringType), nts) checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts) // all convert to string type to check checkEvaluation( Cast(Cast(Literal(nts) cast TimestampType, DateType), StringType), sd) checkEvaluation( - Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), sts) + Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), zts) checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") @@ -377,8 +379,8 @@ class ExpressionEvaluationSuite extends FunSuite { } test("date") { - val d1 = Date.valueOf("1970-01-01") - val d2 = Date.valueOf("1970-01-02") + val d1 = DateUtils.fromJavaDate(Date.valueOf("1970-01-01")) + val d2 = DateUtils.fromJavaDate(Date.valueOf("1970-01-02")) checkEvaluation(Literal(d1) < Literal(d2), true) } @@ -459,22 +461,21 @@ class ExpressionEvaluationSuite extends FunSuite { test("date casting") { val d = Date.valueOf("1970-01-01") - checkEvaluation(Cast(d, ShortType), null) - checkEvaluation(Cast(d, IntegerType), null) - checkEvaluation(Cast(d, LongType), null) - checkEvaluation(Cast(d, FloatType), null) - checkEvaluation(Cast(d, DoubleType), null) - checkEvaluation(Cast(d, DecimalType.Unlimited), null) - checkEvaluation(Cast(d, DecimalType(10, 2)), null) - checkEvaluation(Cast(d, StringType), "1970-01-01") - checkEvaluation(Cast(Cast(d, TimestampType), StringType), "1970-01-01 00:00:00") + checkEvaluation(Cast(Literal(d), ShortType), null) + checkEvaluation(Cast(Literal(d), IntegerType), null) + checkEvaluation(Cast(Literal(d), LongType), null) + checkEvaluation(Cast(Literal(d), FloatType), null) + checkEvaluation(Cast(Literal(d), DoubleType), null) + checkEvaluation(Cast(Literal(d), DecimalType.Unlimited), null) + checkEvaluation(Cast(Literal(d), DecimalType(10, 2)), null) + checkEvaluation(Cast(Literal(d), StringType), "1970-01-01") + checkEvaluation(Cast(Cast(Literal(d), TimestampType), StringType), "1970-01-01 00:00:00") } test("timestamp casting") { val millis = 15 * 1000 + 2 val seconds = millis * 1000 + 2 val ts = new Timestamp(millis) - val ts1 = new Timestamp(15 * 1000) // a timestamp without the milliseconds part val tss = new Timestamp(seconds) checkEvaluation(Cast(ts, ShortType), 15) checkEvaluation(Cast(ts, IntegerType), 15) @@ -846,23 +847,33 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(null, IntegerType)), null, row) - checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(GetField(Literal(null, typeS), "a"), null, row) + def quickBuildGetField(expr: Expression, fieldName: String) = { + expr.dataType match { + case StructType(fields) => + val field = fields.find(_.name == fieldName).get + StructGetField(expr, field, fields.indexOf(field)) + } + } + + def quickResolve(u: UnresolvedGetField) = quickBuildGetField(u.child, u.fieldName) + + checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) + checkEvaluation(quickBuildGetField(Literal(null, typeS), "a"), null, row) val typeS_notNullable = StructType( StructField("a", StringType, nullable = false) :: StructField("b", StringType, nullable = false) :: Nil ) - assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) - assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) + assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) + assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) - assert(GetField(Literal(null, typeS), "a").nullable === true) - assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) + assert(quickBuildGetField(Literal(null, typeS), "a").nullable === true) + assert(quickBuildGetField(Literal(null, typeS_notNullable), "a").nullable === true) checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) - checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row) + checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row) } test("arithmetic") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 264a0eff37d34..72f06e26e05f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -30,7 +30,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("Constant Folding", FixedPoint(50), NullPropagation, ConstantFolding, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 9fdf3efa02bb6..ef10c0aece716 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest @@ -33,7 +33,7 @@ class ConstantFoldingSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("ConstantFolding", Once, ConstantFolding, BooleanSimplification) :: Nil @@ -184,7 +184,7 @@ class ConstantFoldingSuite extends PlanTest { GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3, GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4, - GetField( + UnresolvedGetField( Literal(null, StructType(Seq(StructField("a", IntegerType, true)))), "a") as 'c5, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala new file mode 100644 index 0000000000000..cf42d43823399 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class ConvertToLocalRelationSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("LocalRelation", FixedPoint(100), + ConvertToLocalRelation) :: Nil + } + + test("Project on LocalRelation should be turned into a single LocalRelation") { + val testRelation = LocalRelation( + LocalRelation('a.int, 'b.int).output, + Row(1, 2) :: + Row(4, 5) :: Nil) + + val correctAnswer = LocalRelation( + LocalRelation('a1.int, 'b1.int).output, + Row(1, 3) :: + Row(4, 6) :: Nil) + + val projectOnLocal = testRelation.select( + UnresolvedAttribute("a").as("a1"), + (UnresolvedAttribute("b") + 1).as("b1")) + + val optimized = Optimize(projectOnLocal.analyze) + + comparePlans(optimized, correctAnswer) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala index ae99a3f9ba287..2f3704be59a9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala @@ -29,7 +29,7 @@ class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, NoRelation) + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) val optimizedPlan = DefaultOptimizer(plan) super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ebb123c1f909e..55c6766520a1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -18,23 +18,27 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.expressions.{Count, Explode} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.IntegerType class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("Filter Pushdown", Once, CombineFilters, PushPredicateThroughProject, - PushPredicateThroughJoin) :: Nil + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + ColumnPruning) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -55,6 +59,38 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("column pruning for group") { + val originalQuery = + testRelation + .groupBy('a)('a, Count('b)) + .select('a) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a) + .select('a).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for group with alias") { + val originalQuery = + testRelation + .groupBy('a)('a as 'c, Count('b)) + .select('c) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a as 'c) + .select('c).analyze + + comparePlans(optimized, correctAnswer) + } + // After this line is unimplemented. test("simple push down") { val originalQuery = @@ -348,7 +384,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize(originalQuery.analyze) - comparePlans(analysis.EliminateAnalysisOperators(originalQuery.analyze), optimized) + comparePlans(analysis.EliminateSubQueries(originalQuery.analyze), optimized) } test("joins: conjunctive predicates") { @@ -367,7 +403,7 @@ class FilterPushdownSuite extends PlanTest { left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) } test("joins: conjunctive predicates #2") { @@ -386,7 +422,7 @@ class FilterPushdownSuite extends PlanTest { left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) } test("joins: conjunctive predicates #3") { @@ -409,6 +445,64 @@ class FilterPushdownSuite extends PlanTest { condition = Some("z.a".attr === "x.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + } + + val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) + + test("generate: predicate referenced no generated column") { + val originalQuery = { + testRelationWithArrayType + .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")) + .where(('b >= 5) && ('a > 6)) + } + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = { + testRelationWithArrayType + .where(('b >= 5) && ('a > 6)) + .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze + } + + comparePlans(optimized, correctAnswer) + } + + test("generate: part of conjuncts referenced generated column") { + val generator = Explode(Seq("c"), 'c_arr) + val originalQuery = { + testRelationWithArrayType + .generate(generator, true, false, Some("arr")) + .where(('b >= 5) && ('c > 6)) + } + val optimized = Optimize(originalQuery.analyze) + val referenceResult = { + testRelationWithArrayType + .where('b >= 5) + .generate(generator, true, false, Some("arr")) + .where('c > 6).analyze + } + + // Since newly generated columns get different ids every time being analyzed + // e.g. comparePlans(originalQuery.analyze, originalQuery.analyze) fails. + // So we check operators manually here. + // Filter("c" > 6) + assertResult(classOf[Filter])(optimized.getClass) + assertResult(1)(optimized.asInstanceOf[Filter].condition.references.size) + assertResult("c"){ + optimized.asInstanceOf[Filter].condition.references.toSeq(0).name + } + + // the rest part + comparePlans(optimized.children(0), referenceResult.children(0)) + } + + test("generate: all conjuncts referenced generated column") { + val originalQuery = { + testRelationWithArrayType + .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")) + .where(('c > 6) || ('b > 5)).analyze + } + val optimized = Optimize(originalQuery) + + comparePlans(optimized, originalQuery) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index da912ab382179..233e329cb2038 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest @@ -34,7 +34,7 @@ class OptimizeInSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("ConstantFolding", Once, ConstantFolding, BooleanSimplification, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index dfef87bd9133d..a54751dfa9a12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ @@ -29,7 +29,7 @@ class UnionPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("Union Pushdown", Once, UnionPushdown) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index c4a1f899d8a13..129d091ca03e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.plans import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} import org.apache.spark.sql.catalyst.util._ /** @@ -33,11 +33,11 @@ class PlanTest extends FunSuite { * we must normalize them to check if two different queries are identical. */ protected def normalizeExprIds(plan: LogicalPlan) = { - val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)) - val minId = if (list.isEmpty) 0 else list.min plan transformAllExpressions { case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId)) + AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) + case a: Alias => + Alias(a.child, a.name)(exprId = ExprId(0)) } } @@ -52,4 +52,9 @@ class PlanTest extends FunSuite { |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} """.stripMargin) } + + /** Fails the test if the two expressions do not match */ + protected def compareExpressions(e1: Expression, e2: Expression): Unit = { + comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index cdb843f959704..e7ce92a2160b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -104,4 +104,18 @@ class TreeNodeSuite extends FunSuite { assert(actual === Dummy(None)) } + test("preserves origin") { + CurrentOrigin.setPosition(1,1) + val add = Add(Literal(1), Literal(1)) + CurrentOrigin.reset() + + val transformed = add transform { + case Literal(1, _) => Literal(2) + } + + assert(transformed.origin.line.isDefined) + assert(transformed.origin.startPosition.isDefined) + } + + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala old mode 100755 new mode 100644 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala new file mode 100644 index 0000000000000..1ba21b64603ac --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -0,0 +1,116 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import org.scalatest.FunSuite + +class DataTypeParserSuite extends FunSuite { + + def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { + test(s"parse ${dataTypeString.replace("\n", "")}") { + assert(DataTypeParser(dataTypeString) === expectedDataType) + } + } + + def unsupported(dataTypeString: String): Unit = { + test(s"$dataTypeString is not supported") { + intercept[DataTypeException](DataTypeParser(dataTypeString)) + } + } + + checkDataType("int", IntegerType) + checkDataType("BooLean", BooleanType) + checkDataType("tinYint", ByteType) + checkDataType("smallINT", ShortType) + checkDataType("INT", IntegerType) + checkDataType("bigint", LongType) + checkDataType("float", FloatType) + checkDataType("dOUBle", DoubleType) + checkDataType("decimal(10, 5)", DecimalType(10, 5)) + checkDataType("decimal", DecimalType.Unlimited) + checkDataType("DATE", DateType) + checkDataType("timestamp", TimestampType) + checkDataType("string", StringType) + checkDataType("varchAr(20)", StringType) + checkDataType("BINARY", BinaryType) + + checkDataType("array", ArrayType(DoubleType, true)) + checkDataType("Array>", ArrayType(MapType(IntegerType, ByteType, true), true)) + checkDataType( + "array>", + ArrayType(StructType(StructField("tinYint", ByteType, true) :: Nil), true) + ) + checkDataType("MAP", MapType(IntegerType, StringType, true)) + checkDataType("MAp>", MapType(IntegerType, ArrayType(DoubleType), true)) + checkDataType( + "MAP>", + MapType(IntegerType, StructType(StructField("varchar", StringType, true) :: Nil), true) + ) + + checkDataType( + "struct", + StructType( + StructField("intType", IntegerType, true) :: + StructField("ts", TimestampType, true) :: Nil) + ) + // It is fine to use the data type string as the column name. + checkDataType( + "Struct", + StructType( + StructField("int", IntegerType, true) :: + StructField("timestamp", TimestampType, true) :: Nil) + ) + checkDataType( + """ + |struct< + | struct:struct, + | MAP:Map, + | arrAy:Array> + """.stripMargin, + StructType( + StructField("struct", + StructType( + StructField("deciMal", DecimalType.Unlimited, true) :: + StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: + StructField("MAP", MapType(TimestampType, StringType), true) :: + StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) + ) + // A column name can be a reserved word in our DDL parser and SqlParser. + checkDataType( + "Struct", + StructType( + StructField("TABLE", StringType, true) :: + StructField("CASE", BooleanType, true) :: Nil) + ) + // Use backticks to quote column names having special characters. + checkDataType( + "struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>", + StructType( + StructField("x+y", IntegerType, true) :: + StructField("!@#$%^&*()", StringType, true) :: + StructField("1_2.345<>:\"", StringType, true) :: Nil) + ) + // Empty struct. + checkDataType("strUCt<>", StructType(Nil)) + + unsupported("it is not a data type") + unsupported("struct") + unsupported("struct") + unsupported("struct<`x``y` int>") +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index c147be9f6b1ae..a1341ea13d810 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -106,8 +106,8 @@ class DataTypeSuite extends FunSuite { checkDefaultSize(DoubleType, 8) checkDefaultSize(DecimalType(10, 5), 4096) checkDefaultSize(DecimalType.Unlimited, 4096) - checkDefaultSize(DateType, 8) - checkDefaultSize(TimestampType, 8) + checkDefaultSize(DateType, 4) + checkDefaultSize(TimestampType,12) checkDefaultSize(StringType, 4096) checkDefaultSize(BinaryType, 4096) checkDefaultSize(ArrayType(DoubleType, true), 800) @@ -115,4 +115,87 @@ class DataTypeSuite extends FunSuite { checkDefaultSize(MapType(IntegerType, StringType, true), 410000) checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400) checkDefaultSize(structType, 812) + + def checkEqualsIgnoreCompatibleNullability( + from: DataType, + to: DataType, + expected: Boolean): Unit = { + val testName = + s"equalsIgnoreCompatibleNullability: (from: ${from}, to: ${to})" + test(testName) { + assert(DataType.equalsIgnoreCompatibleNullability(from, to) === expected) + } + } + + checkEqualsIgnoreCompatibleNullability( + from = ArrayType(DoubleType, containsNull = true), + to = ArrayType(DoubleType, containsNull = true), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = ArrayType(DoubleType, containsNull = false), + to = ArrayType(DoubleType, containsNull = false), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = ArrayType(DoubleType, containsNull = false), + to = ArrayType(DoubleType, containsNull = true), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = ArrayType(DoubleType, containsNull = true), + to = ArrayType(DoubleType, containsNull = false), + expected = false) + checkEqualsIgnoreCompatibleNullability( + from = ArrayType(DoubleType, containsNull = false), + to = ArrayType(StringType, containsNull = false), + expected = false) + + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, DoubleType, valueContainsNull = true), + to = MapType(StringType, DoubleType, valueContainsNull = true), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, DoubleType, valueContainsNull = false), + to = MapType(StringType, DoubleType, valueContainsNull = false), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, DoubleType, valueContainsNull = false), + to = MapType(StringType, DoubleType, valueContainsNull = true), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, DoubleType, valueContainsNull = true), + to = MapType(StringType, DoubleType, valueContainsNull = false), + expected = false) + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), + expected = false) + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), + expected = true) + + + checkEqualsIgnoreCompatibleNullability( + from = StructType(StructField("a", StringType, nullable = true) :: Nil), + to = StructType(StructField("a", StringType, nullable = true) :: Nil), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = StructType(StructField("a", StringType, nullable = false) :: Nil), + to = StructType(StructField("a", StringType, nullable = false) :: Nil), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = StructType(StructField("a", StringType, nullable = false) :: Nil), + to = StructType(StructField("a", StringType, nullable = true) :: Nil), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = StructType(StructField("a", StringType, nullable = true) :: Nil), + to = StructType(StructField("a", StringType, nullable = false) :: Nil), + expected = false) + checkEqualsIgnoreCompatibleNullability( + from = StructType( + StructField("a", StringType, nullable = false) :: + StructField("b", StringType, nullable = true) :: Nil), + to = StructType( + StructField("a", StringType, nullable = false) :: + StructField("b", StringType, nullable = false) :: Nil), + expected = false) } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 1a0c77d282307..5b2a1890cddc8 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml @@ -66,6 +66,11 @@ jackson-databind 2.3.0 + + org.jodd + jodd-core + ${jodd.version} + junit junit @@ -94,15 +99,17 @@ 9.3-1102-jdbc41 test - - com.spotify - docker-client - 2.7.5 - test - target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + ../../python + + pyspark/sql/*.py + + + diff --git a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java new file mode 100644 index 0000000000000..a40be526d0d11 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql; + +/** + * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. + */ +public enum SaveMode { + /** + * Append mode means that when saving a DataFrame to a data source, if data/table already exists, + * contents of the DataFrame are expected to be appended to existing data. + */ + Append, + /** + * Overwrite mode means that when saving a DataFrame to a data source, + * if data/table already exists, existing data is expected to be overwritten by the contents of + * the DataFrame. + */ + Overwrite, + /** + * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists, + * an exception is expected to be thrown. + */ + ErrorIfExists, + /** + * Ignore mode means that when saving a DataFrame to a data source, if data already exists, + * the save operation is expected to not save the contents of the DataFrame and to not + * change the existing data. + */ + Ignore +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java b/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java deleted file mode 100644 index aa441b2096f18..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc; - -import org.apache.spark.Partition; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.DataFrame; - -public class JDBCUtils { - /** - * Construct a DataFrame representing the JDBC table at the database - * specified by url with table name table. - */ - public static DataFrame jdbcRDD(SQLContext sql, String url, String table) { - Partition[] parts = new Partition[1]; - parts[0] = new JDBCPartition(null, 0); - return sql.baseRelationToDataFrame( - new JDBCRelation(url, table, parts, sql)); - } - - /** - * Construct a DataFrame representing the JDBC table at the database - * specified by url with table name table partitioned by parts. - * Here, parts is an array of expressions suitable for insertion into a WHERE - * clause; each one defines one partition. - */ - public static DataFrame jdbcRDD(SQLContext sql, String url, String table, String[] parts) { - Partition[] partitions = new Partition[parts.length]; - for (int i = 0; i < parts.length; i++) - partitions[i] = new JDBCPartition(parts[i], i); - return sql.baseRelationToDataFrame( - new JDBCRelation(url, table, partitions, sql)); - } - - private static JavaJDBCTrampoline trampoline = new JavaJDBCTrampoline(); - - public static void createJDBCTable(DataFrame rdd, String url, String table, boolean allowExisting) { - trampoline.createJDBCTable(rdd, url, table, allowExisting); - } - - public static void insertIntoJDBC(DataFrame rdd, String url, String table, boolean overwrite) { - trampoline.insertIntoJDBC(rdd, url, table, overwrite); - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index f1949aa5dd74b..ca4a127120b37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -71,11 +71,17 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { } } + /** Clears all cached tables. */ private[sql] def clearCache(): Unit = writeLock { cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) cachedData.clear() } + /** Checks if the cache is empty. */ + private[sql] def isEmpty: Boolean = readLock { + cachedData.isEmpty + } + /** * Caches the data produced by the logical representation of the given schema rdd. Unlike * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 4aa37219e13a6..1f5c30b487951 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,90 +17,55 @@ package org.apache.spark.sql -import scala.annotation.tailrec import scala.language.implicitConversions -import org.apache.spark.sql.Dsl.lit +import org.apache.spark.annotation.Experimental +import org.apache.spark.Logging +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Subquery, Project, LogicalPlan} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField} import org.apache.spark.sql.types._ private[sql] object Column { - def apply(colName: String): Column = new IncomputableColumn(colName) + def apply(colName: String): Column = new Column(colName) - def apply(expr: Expression): Column = new IncomputableColumn(expr) - - def apply(sqlContext: SQLContext, plan: LogicalPlan, expr: Expression): Column = { - new ComputableColumn(sqlContext, plan, expr) - } + def apply(expr: Expression): Column = new Column(expr) def unapply(col: Column): Option[Expression] = Some(col.expr) } /** + * :: Experimental :: * A column in a [[DataFrame]]. * - * `Column` instances can be created by: - * {{{ - * // 1. Select a column out of a DataFrame - * df("colName") - * - * // 2. Create a literal expression - * Literal(1) - * - * // 3. Create new columns from - * }}} - * + * @groupname java_expr_ops Java-specific expression operators. + * @groupname expr_ops Expression operators. + * @groupname df_ops DataFrame functions. + * @groupname Ungrouped Support functions for DataFrames. */ -// TODO: Improve documentation. -trait Column extends DataFrame { +@Experimental +class Column(protected[sql] val expr: Expression) extends Logging { - protected[sql] def expr: Expression + def this(name: String) = this(name match { + case "*" => UnresolvedStar(None) + case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) + case _ => UnresolvedAttribute(name) + }) - /** - * Returns true iff the [[Column]] is computable. - */ - def isComputable: Boolean - - private def computableCol(baseCol: ComputableColumn, expr: Expression) = { - val plan = Project(Seq(expr match { - case named: NamedExpression => named - case unnamed: Expression => Alias(unnamed, "col")() - }), baseCol.plan) - Column(baseCol.sqlContext, plan, expr) - } + /** Creates a column based on the given expression. */ + implicit private def exprToColumn(newExpr: Expression): Column = new Column(newExpr) - private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = { - // Removes all the top level projection and subquery so we can get to the underlying plan. - @tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match { - case Project(_, child) => stripProject(child) - case Subquery(_, child) => stripProject(child) - case _ => p - } + override def toString: String = expr.prettyString - (this, lit(otherValue)) match { - case (left: ComputableColumn, right: ComputableColumn) => - if (stripProject(left.plan).sameResult(stripProject(right.plan))) { - computableCol(right, newExpr(right)) - } else { - Column(newExpr(right)) - } - case (left: ComputableColumn, right) => computableCol(left, newExpr(right)) - case (_, right: ComputableColumn) => computableCol(right, newExpr(right)) - case (_, right) => Column(newExpr(right)) - } + override def equals(that: Any): Boolean = that match { + case that: Column => that.expr.equals(this.expr) + case _ => false } - /** Creates a column based on the given expression. */ - private def exprToColumn(newExpr: Expression, computable: Boolean = true): Column = { - this match { - case c: ComputableColumn if computable => computableCol(c, newExpr) - case _ => Column(newExpr) - } - } + override def hashCode: Int = this.expr.hashCode /** * Unary minus, i.e. negate the expression. @@ -109,20 +74,13 @@ trait Column extends DataFrame { * df.select( -df("amount") ) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.select( negate(col("amount") ); * }}} + * + * @group expr_ops */ - def unary_- : Column = exprToColumn(UnaryMinus(expr)) - - /** - * Bitwise NOT. - * {{{ - * // Scala: select the flags column and negate every bit. - * df.select( ~df("flags") ) - * }}} - */ - def unary_~ : Column = exprToColumn(BitwiseNot(expr)) + def unary_- : Column = UnaryMinus(expr) /** * Inversion of boolean expression, i.e. NOT. @@ -131,12 +89,13 @@ trait Column extends DataFrame { * df.filter( !df("isActive") ) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.filter( not(df.col("isActive")) ); * }} + * + * @group expr_ops */ - def unary_! : Column = exprToColumn(Not(expr)) - + def unary_! : Column = Not(expr) /** * Equality test. @@ -145,12 +104,20 @@ trait Column extends DataFrame { * df.filter( df("colA") === df("colB") ) * * // Java - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.filter( col("colA").equalTo(col("colB")) ); * }}} - */ - def === (other: Any): Column = constructColumn(other) { o => - EqualTo(expr, o.expr) + * + * @group expr_ops + */ + def === (other: Any): Column = { + val right = lit(other).expr + if (this.expr == right) { + logWarning( + s"Constructing trivially true equals predicate, '${this.expr} = $right'. " + + "Perhaps you need to use aliases.") + } + EqualTo(expr, right) } /** @@ -160,9 +127,11 @@ trait Column extends DataFrame { * df.filter( df("colA") === df("colB") ) * * // Java - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.filter( col("colA").equalTo(col("colB")) ); * }}} + * + * @group expr_ops */ def equalTo(other: Any): Column = this === other @@ -174,13 +143,29 @@ trait Column extends DataFrame { * df.select( !(df("colA") === df("colB")) ) * * // Java: - * import static org.apache.spark.sql.Dsl.*; - * df.filter( not(col("colA").equalTo(col("colB"))) ); + * import static org.apache.spark.sql.functions.*; + * df.filter( col("colA").notEqual(col("colB")) ); * }}} + * + * @group expr_ops */ - def !== (other: Any): Column = constructColumn(other) { o => - Not(EqualTo(expr, o.expr)) - } + def !== (other: Any): Column = Not(EqualTo(expr, lit(other).expr)) + + /** + * Inequality test. + * {{{ + * // Scala: + * df.select( df("colA") !== df("colB") ) + * df.select( !(df("colA") === df("colB")) ) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df.filter( col("colA").notEqual(col("colB")) ); + * }}} + * + * @group java_expr_ops + */ + def notEqual(other: Any): Column = Not(EqualTo(expr, lit(other).expr)) /** * Greater than. @@ -189,13 +174,13 @@ trait Column extends DataFrame { * people.select( people("age") > 21 ) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * people.select( people("age").gt(21) ); * }}} + * + * @group expr_ops */ - def > (other: Any): Column = constructColumn(other) { o => - GreaterThan(expr, o.expr) - } + def > (other: Any): Column = GreaterThan(expr, lit(other).expr) /** * Greater than. @@ -204,9 +189,11 @@ trait Column extends DataFrame { * people.select( people("age") > lit(21) ) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * people.select( people("age").gt(21) ); * }}} + * + * @group java_expr_ops */ def gt(other: Any): Column = this > other @@ -219,10 +206,10 @@ trait Column extends DataFrame { * // Java: * people.select( people("age").lt(21) ); * }}} + * + * @group expr_ops */ - def < (other: Any): Column = constructColumn(other) { o => - LessThan(expr, o.expr) - } + def < (other: Any): Column = LessThan(expr, lit(other).expr) /** * Less than. @@ -233,6 +220,8 @@ trait Column extends DataFrame { * // Java: * people.select( people("age").lt(21) ); * }}} + * + * @group java_expr_ops */ def lt(other: Any): Column = this < other @@ -245,10 +234,10 @@ trait Column extends DataFrame { * // Java: * people.select( people("age").leq(21) ); * }}} + * + * @group expr_ops */ - def <= (other: Any): Column = constructColumn(other) { o => - LessThanOrEqual(expr, o.expr) - } + def <= (other: Any): Column = LessThanOrEqual(expr, lit(other).expr) /** * Less than or equal to. @@ -259,6 +248,8 @@ trait Column extends DataFrame { * // Java: * people.select( people("age").leq(21) ); * }}} + * + * @group java_expr_ops */ def leq(other: Any): Column = this <= other @@ -271,10 +262,10 @@ trait Column extends DataFrame { * // Java: * people.select( people("age").geq(21) ) * }}} + * + * @group expr_ops */ - def >= (other: Any): Column = constructColumn(other) { o => - GreaterThanOrEqual(expr, o.expr) - } + def >= (other: Any): Column = GreaterThanOrEqual(expr, lit(other).expr) /** * Greater than or equal to an expression. @@ -285,30 +276,38 @@ trait Column extends DataFrame { * // Java: * people.select( people("age").geq(21) ) * }}} + * + * @group java_expr_ops */ def geq(other: Any): Column = this >= other /** * Equality test that is safe for null values. + * + * @group expr_ops */ - def <=> (other: Any): Column = constructColumn(other) { o => - EqualNullSafe(expr, o.expr) - } + def <=> (other: Any): Column = EqualNullSafe(expr, lit(other).expr) /** * Equality test that is safe for null values. + * + * @group java_expr_ops */ def eqNullSafe(other: Any): Column = this <=> other /** * True if the current expression is null. + * + * @group expr_ops */ - def isNull: Column = exprToColumn(IsNull(expr)) + def isNull: Column = IsNull(expr) /** * True if the current expression is NOT null. + * + * @group expr_ops */ - def isNotNull: Column = exprToColumn(IsNotNull(expr)) + def isNotNull: Column = IsNotNull(expr) /** * Boolean OR. @@ -319,10 +318,10 @@ trait Column extends DataFrame { * // Java: * people.filter( people("inSchool").or(people("isEmployed")) ); * }}} + * + * @group expr_ops */ - def || (other: Any): Column = constructColumn(other) { o => - Or(expr, o.expr) - } + def || (other: Any): Column = Or(expr, lit(other).expr) /** * Boolean OR. @@ -333,6 +332,8 @@ trait Column extends DataFrame { * // Java: * people.filter( people("inSchool").or(people("isEmployed")) ); * }}} + * + * @group java_expr_ops */ def or(other: Column): Column = this || other @@ -345,10 +346,10 @@ trait Column extends DataFrame { * // Java: * people.select( people("inSchool").and(people("isEmployed")) ); * }}} + * + * @group expr_ops */ - def && (other: Any): Column = constructColumn(other) { o => - And(expr, o.expr) - } + def && (other: Any): Column = And(expr, lit(other).expr) /** * Boolean AND. @@ -359,30 +360,11 @@ trait Column extends DataFrame { * // Java: * people.select( people("inSchool").and(people("isEmployed")) ); * }}} + * + * @group java_expr_ops */ def and(other: Column): Column = this && other - /** - * Bitwise AND. - */ - def & (other: Any): Column = constructColumn(other) { o => - BitwiseAnd(expr, o.expr) - } - - /** - * Bitwise OR with an expression. - */ - def | (other: Any): Column = constructColumn(other) { o => - BitwiseOr(expr, o.expr) - } - - /** - * Bitwise XOR with an expression. - */ - def ^ (other: Any): Column = constructColumn(other) { o => - BitwiseXor(expr, o.expr) - } - /** * Sum of this expression and another expression. * {{{ @@ -392,10 +374,10 @@ trait Column extends DataFrame { * // Java: * people.select( people("height").plus(people("weight")) ); * }}} + * + * @group expr_ops */ - def + (other: Any): Column = constructColumn(other) { o => - Add(expr, o.expr) - } + def + (other: Any): Column = Add(expr, lit(other).expr) /** * Sum of this expression and another expression. @@ -406,6 +388,8 @@ trait Column extends DataFrame { * // Java: * people.select( people("height").plus(people("weight")) ); * }}} + * + * @group java_expr_ops */ def plus(other: Any): Column = this + other @@ -418,10 +402,10 @@ trait Column extends DataFrame { * // Java: * people.select( people("height").minus(people("weight")) ); * }}} + * + * @group expr_ops */ - def - (other: Any): Column = constructColumn(other) { o => - Subtract(expr, o.expr) - } + def - (other: Any): Column = Subtract(expr, lit(other).expr) /** * Subtraction. Subtract the other expression from this expression. @@ -432,6 +416,8 @@ trait Column extends DataFrame { * // Java: * people.select( people("height").minus(people("weight")) ); * }}} + * + * @group java_expr_ops */ def minus(other: Any): Column = this - other @@ -444,10 +430,10 @@ trait Column extends DataFrame { * // Java: * people.select( people("height").multiply(people("weight")) ); * }}} + * + * @group expr_ops */ - def * (other: Any): Column = constructColumn(other) { o => - Multiply(expr, o.expr) - } + def * (other: Any): Column = Multiply(expr, lit(other).expr) /** * Multiplication of this expression and another expression. @@ -458,6 +444,8 @@ trait Column extends DataFrame { * // Java: * people.select( people("height").multiply(people("weight")) ); * }}} + * + * @group java_expr_ops */ def multiply(other: Any): Column = this * other @@ -470,10 +458,10 @@ trait Column extends DataFrame { * // Java: * people.select( people("height").divide(people("weight")) ); * }}} + * + * @group expr_ops */ - def / (other: Any): Column = constructColumn(other) { o => - Divide(expr, o.expr) - } + def / (other: Any): Column = Divide(expr, lit(other).expr) /** * Division this expression by another expression. @@ -484,74 +472,113 @@ trait Column extends DataFrame { * // Java: * people.select( people("height").divide(people("weight")) ); * }}} + * + * @group java_expr_ops */ def divide(other: Any): Column = this / other /** * Modulo (a.k.a. remainder) expression. + * + * @group expr_ops */ - def % (other: Any): Column = constructColumn(other) { o => - Remainder(expr, o.expr) - } + def % (other: Any): Column = Remainder(expr, lit(other).expr) /** * Modulo (a.k.a. remainder) expression. + * + * @group java_expr_ops */ def mod(other: Any): Column = this % other /** * A boolean expression that is evaluated to true if the value of this expression is contained * by the evaluated values of the arguments. + * + * @group expr_ops */ @scala.annotation.varargs - def in(list: Column*): Column = { - new IncomputableColumn(In(expr, list.map(_.expr))) - } + def in(list: Column*): Column = In(expr, list.map(_.expr)) - def like(literal: String): Column = exprToColumn(Like(expr, lit(literal).expr)) + /** + * SQL like expression. + * + * @group expr_ops + */ + def like(literal: String): Column = Like(expr, lit(literal).expr) - def rlike(literal: String): Column = exprToColumn(RLike(expr, lit(literal).expr)) + /** + * SQL RLIKE expression (LIKE with Regex). + * + * @group expr_ops + */ + def rlike(literal: String): Column = RLike(expr, lit(literal).expr) /** * An expression that gets an item at position `ordinal` out of an array. + * + * @group expr_ops */ - def getItem(ordinal: Int): Column = exprToColumn(GetItem(expr, Literal(ordinal))) + def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal)) /** * An expression that gets a field by name in a [[StructField]]. + * + * @group expr_ops */ - def getField(fieldName: String): Column = exprToColumn(GetField(expr, fieldName)) + def getField(fieldName: String): Column = UnresolvedGetField(expr, fieldName) /** * An expression that returns a substring. * @param startPos expression for the starting position. * @param len expression for the length of the substring. + * + * @group expr_ops */ - def substr(startPos: Column, len: Column): Column = { - new IncomputableColumn(Substring(expr, startPos.expr, len.expr)) - } + def substr(startPos: Column, len: Column): Column = Substring(expr, startPos.expr, len.expr) /** * An expression that returns a substring. * @param startPos starting position. * @param len length of the substring. + * + * @group expr_ops */ - def substr(startPos: Int, len: Int): Column = this.substr(lit(startPos), lit(len)) + def substr(startPos: Int, len: Int): Column = Substring(expr, lit(startPos).expr, lit(len).expr) - def contains(other: Any): Column = constructColumn(other) { o => - Contains(expr, o.expr) - } + /** + * Contains the other element. + * + * @group expr_ops + */ + def contains(other: Any): Column = Contains(expr, lit(other).expr) - def startsWith(other: Column): Column = constructColumn(other) { o => - StartsWith(expr, o.expr) - } + /** + * String starts with. + * + * @group expr_ops + */ + def startsWith(other: Column): Column = StartsWith(expr, lit(other).expr) + /** + * String starts with another string literal. + * + * @group expr_ops + */ def startsWith(literal: String): Column = this.startsWith(lit(literal)) - def endsWith(other: Column): Column = constructColumn(other) { o => - EndsWith(expr, o.expr) - } + /** + * String ends with. + * + * @group expr_ops + */ + def endsWith(other: Column): Column = EndsWith(expr, lit(other).expr) + /** + * String ends with another string literal. + * + * @group expr_ops + */ def endsWith(literal: String): Column = this.endsWith(lit(literal)) /** @@ -560,8 +587,21 @@ trait Column extends DataFrame { * // Renames colA to colB in select output. * df.select($"colA".as("colB")) * }}} + * + * @group expr_ops + */ + def as(alias: String): Column = Alias(expr, alias)() + + /** + * Gives the column an alias. + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".as('colB)) + * }}} + * + * @group expr_ops */ - override def as(alias: String): Column = exprToColumn(Alias(expr, alias)()) + def as(alias: Symbol): Column = Alias(expr, alias.name)() /** * Casts the column to a different data type. @@ -573,8 +613,14 @@ trait Column extends DataFrame { * // equivalent to * df.select(df("colA").cast("int")) * }}} + * + * @group expr_ops */ - def cast(to: DataType): Column = exprToColumn(Cast(expr, to)) + def cast(to: DataType): Column = expr match { + // Lift alias out of cast so we can support col.as("name").cast(IntegerType) + case Alias(childExpr, name) => Alias(Cast(childExpr, to), name)() + case _ => Cast(expr, to) + } /** * Casts the column to a different data type, using the canonical string representation @@ -584,31 +630,60 @@ trait Column extends DataFrame { * // Casts colA to integer. * df.select(df("colA").cast("int")) * }}} + * + * @group expr_ops */ - def cast(to: String): Column = exprToColumn( - Cast(expr, to.toLowerCase match { - case "string" => StringType - case "boolean" => BooleanType - case "byte" => ByteType - case "short" => ShortType - case "int" => IntegerType - case "long" => LongType - case "float" => FloatType - case "double" => DoubleType - case "decimal" => DecimalType.Unlimited - case "date" => DateType - case "timestamp" => TimestampType - case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""") - }) - ) - - def desc: Column = exprToColumn(SortOrder(expr, Descending), computable = false) - - def asc: Column = exprToColumn(SortOrder(expr, Ascending), computable = false) + def cast(to: String): Column = cast(DataTypeParser(to)) + + /** + * Returns an ordering used in sorting. + * {{{ + * // Scala: sort a DataFrame by age column in descending order. + * df.sort(df("age").desc) + * + * // Java + * df.sort(df.col("age").desc()); + * }}} + * + * @group expr_ops + */ + def desc: Column = SortOrder(expr, Descending) + + /** + * Returns an ordering used in sorting. + * {{{ + * // Scala: sort a DataFrame by age column in ascending order. + * df.sort(df("age").asc) + * + * // Java + * df.sort(df.col("age").asc()); + * }}} + * + * @group expr_ops + */ + def asc: Column = SortOrder(expr, Ascending) + + /** + * Prints the expression to the console for debugging purpose. + * + * @group df_ops + */ + def explain(extended: Boolean): Unit = { + if (extended) { + println(expr) + } else { + println(expr.prettyString) + } + } } -class ColumnName(name: String) extends IncomputableColumn(name) { +/** + * :: Experimental :: + * A convenient class used for constructing schema. + */ +@Experimental +class ColumnName(name: String) extends Column(name) { /** Creates a new AttributeReference of type boolean */ def boolean: StructField = StructField(name, BooleanType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f3bc07ae5238c..73c0fb493b255 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -17,102 +17,318 @@ package org.apache.spark.sql +import java.io.CharArrayWriter +import java.sql.DriverManager + +import scala.collection.JavaConversions._ +import scala.language.implicitConversions import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal + +import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} +import org.apache.spark.sql.jdbc.JDBCWriteDetails +import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.types._ +import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} import org.apache.spark.util.Utils private[sql] object DataFrame { def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { - new DataFrameImpl(sqlContext, logicalPlan) + new DataFrame(sqlContext, logicalPlan) } } /** - * A collection of rows that have the same columns. + * :: Experimental :: + * A distributed collection of data organized into named columns. * - * A [[DataFrame]] is equivalent to a relational table in Spark SQL, and can be created using - * various functions in [[SQLContext]]. + * A [[DataFrame]] is equivalent to a relational table in Spark SQL. There are multiple ways + * to create a [[DataFrame]]: * {{{ + * // Create a DataFrame from Parquet files * val people = sqlContext.parquetFile("...") + * + * // Create a DataFrame from data sources + * val df = sqlContext.load("...", "json") * }}} * * Once created, it can be manipulated using the various domain-specific-language (DSL) functions - * defined in: [[DataFrame]] (this class), [[Column]], [[Dsl]] for the DSL. + * defined in: [[DataFrame]] (this class), [[Column]], and [[functions]]. * - * To select a column from the data frame, use the apply method: + * To select a column from the data frame, use `apply` method in Scala and `col` in Java. * {{{ * val ageCol = people("age") // in Scala - * Column ageCol = people.apply("age") // in Java + * Column ageCol = people.col("age") // in Java * }}} * * Note that the [[Column]] type can also be manipulated through its various functions. - * {{ + * {{{ * // The following creates a new column that increases everybody's age by 10. * people("age") + 10 // in Scala - * }} + * people.col("age").plus(10); // in Java + * }}} * - * A more concrete example: + * A more concrete example in Scala: * {{{ * // To create DataFrame using SQLContext * val people = sqlContext.parquetFile("...") * val department = sqlContext.parquetFile("...") * - * people.filter("age" > 30) + * people.filter("age > 30") * .join(department, people("deptId") === department("id")) * .groupBy(department("name"), "gender") * .agg(avg(people("salary")), max(people("age"))) * }}} + * + * and in Java: + * {{{ + * // To create DataFrame using SQLContext + * DataFrame people = sqlContext.parquetFile("..."); + * DataFrame department = sqlContext.parquetFile("..."); + * + * people.filter("age".gt(30)) + * .join(department, people.col("deptId").equalTo(department("id"))) + * .groupBy(department.col("name"), "gender") + * .agg(avg(people.col("salary")), max(people.col("age"))); + * }}} + * + * @groupname basic Basic DataFrame functions + * @groupname dfops Language Integrated Queries + * @groupname rdd RDD Operations + * @groupname output Output Operations + * @groupname action Actions */ // TODO: Improve documentation. -trait DataFrame extends RDDApi[Row] { +@Experimental +class DataFrame private[sql]( + @transient val sqlContext: SQLContext, + @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution) + extends RDDApi[Row] with Serializable { + + /** + * A constructor that automatically analyzes the logical plan. + * + * This reports error eagerly as the [[DataFrame]] is constructed, unless + * [[SQLConf.dataFrameEagerAnalysis]] is turned off. + */ + def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = { + this(sqlContext, { + val qe = sqlContext.executePlan(logicalPlan) + if (sqlContext.conf.dataFrameEagerAnalysis) { + qe.assertAnalyzed() // This should force analysis and throw errors if there are any + } + qe + }) + } + + @transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match { + // For various commands (like DDL) and queries with side effects, we force query optimization to + // happen right away to let these side effects take place eagerly. + case _: Command | + _: InsertIntoTable | + _: CreateTableAsSelect[_] | + _: CreateTableUsingAsSelect | + _: WriteToFile => + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + case _ => + queryExecution.analyzed + } - val sqlContext: SQLContext + /** + * An implicit conversion function internal to this class for us to avoid doing + * "new DataFrame(...)" everywhere. + */ + @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = { + new DataFrame(sqlContext, logicalPlan) + } - @DeveloperApi - def queryExecution: SQLContext#QueryExecution + protected[sql] def resolve(colName: String): NamedExpression = { + queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse { + throw new AnalysisException( + s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") + } + } - protected[sql] def logicalPlan: LogicalPlan + protected[sql] def numericColumns: Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => + queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get + } + } + + /** + * Internal API for Python + * @param numRows Number of rows to show + */ + private[sql] def showString(numRows: Int): String = { + val data = take(numRows) + val numCols = schema.fieldNames.length + + // For cells that are beyond 20 characters, replace it with the first 17 and "..." + val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => + row.toSeq.map { cell => + val str = if (cell == null) "null" else cell.toString + if (str.length > 20) str.substring(0, 17) + "..." else str + }: Seq[String] + } + + // Compute the width of each column + val colWidths = Array.fill(numCols)(0) + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), cell.length) + } + } + + // Pad the cells + rows.map { row => + row.zipWithIndex.map { case (cell, i) => + String.format(s"%-${colWidths(i)}s", cell) + }.mkString(" ") + }.mkString("\n") + } + + override def toString: String = { + try { + schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") + } catch { + case NonFatal(e) => + s"Invalid tree; ${e.getMessage}:\n$queryExecution" + } + } /** Left here for backward compatibility. */ - @deprecated("1.3.0", "use toDataFrame") + @deprecated("1.3.0", "use toDF") def toSchemaRDD: DataFrame = this /** - * Returns the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala. + * Returns the object itself. + * @group basic */ - def toDataFrame: DataFrame = this + // This is declared with parentheses to prevent the Scala compiler from treating + // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + def toDF(): DataFrame = this /** * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion * from a RDD of tuples into a [[DataFrame]] with meaningful names. For example: * {{{ * val rdd: RDD[(Int, String)] = ... - * rdd.toDataFrame // this implicit conversion creates a DataFrame with column name _1 and _2 - * rdd.toDataFrame("id", "name") // this creates a DataFrame with column name "id" and "name" + * rdd.toDF() // this implicit conversion creates a DataFrame with column name _1 and _2 + * rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name" * }}} + * @group basic */ @scala.annotation.varargs - def toDataFrame(colName: String, colNames: String*): DataFrame + def toDF(colNames: String*): DataFrame = { + require(schema.size == colNames.size, + "The number of columns doesn't match.\n" + + s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" + + s"New column names (${colNames.size}): " + colNames.mkString(", ")) + + val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => + Column(oldAttribute).as(newName) + } + select(newCols :_*) + } - /** Returns the schema of this [[DataFrame]]. */ - def schema: StructType + /** + * Returns the schema of this [[DataFrame]]. + * @group basic + */ + def schema: StructType = queryExecution.analyzed.schema - /** Returns all column names and their data types as an array. */ - def dtypes: Array[(String, String)] + /** + * Returns all column names and their data types as an array. + * @group basic + */ + def dtypes: Array[(String, String)] = schema.fields.map { field => + (field.name, field.dataType.toString) + } - /** Returns all column names as an array. */ + /** + * Returns all column names as an array. + * @group basic + */ def columns: Array[String] = schema.fields.map(_.name) - /** Prints the schema to the console in a nice tree format. */ - def printSchema(): Unit + /** + * Prints the schema to the console in a nice tree format. + * @group basic + */ + def printSchema(): Unit = println(schema.treeString) + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @group basic + */ + def explain(extended: Boolean): Unit = { + ExplainCommand( + queryExecution.logical, + extended = extended).queryExecution.executedPlan.executeCollect().map { + r => println(r.getString(0)) + } + } + + /** + * Only prints the physical plan to the console for debugging purposes. + * @group basic + */ + def explain(): Unit = explain(extended = false) + + /** + * Returns true if the `collect` and `take` methods can be run locally + * (without any Spark executors). + * @group basic + */ + def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] + + /** + * Displays the [[DataFrame]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * + * @group action + */ + def show(numRows: Int): Unit = println(showString(numRows)) + + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * @group action + */ + def show(): Unit = show(20) + + /** + * Returns a [[DataFrameNaFunctions]] for working with missing data. + * {{{ + * // Dropping rows containing any null values. + * df.na.drop() + * }}} + * + * @group dfops + */ + def na: DataFrameNaFunctions = new DataFrameNaFunctions(this) /** * Cartesian join with another [[DataFrame]]. @@ -120,8 +336,11 @@ trait DataFrame extends RDDApi[Row] { * Note that cartesian joins are very expensive without an extra filter that can be pushed down. * * @param right Right side of the join operation. + * @group dfops */ - def join(right: DataFrame): DataFrame + def join(right: DataFrame): DataFrame = { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + } /** * Inner join with another [[DataFrame]], using the given join expression. @@ -131,28 +350,34 @@ trait DataFrame extends RDDApi[Row] { * df1.join(df2, $"df1Key" === $"df2Key") * df1.join(df2).where($"df1Key" === $"df2Key") * }}} + * @group dfops */ - def join(right: DataFrame, joinExprs: Column): DataFrame + def join(right: DataFrame, joinExprs: Column): DataFrame = { + Join(logicalPlan, right.logicalPlan, joinType = Inner, Some(joinExprs.expr)) + } /** - * Join with another [[DataFrame]], usin g the given join expression. The following performs + * Join with another [[DataFrame]], using the given join expression. The following performs * a full outer join between `df1` and `df2`. * * {{{ * // Scala: - * import org.apache.spark.sql.dsl._ - * df1.join(df2, "outer", $"df1Key" === $"df2Key") + * import org.apache.spark.sql.functions._ + * df1.join(df2, $"df1Key" === $"df2Key", "outer") * * // Java: - * import static org.apache.spark.sql.Dsl.*; - * df1.join(df2, "outer", col("df1Key") === col("df2Key")); + * import static org.apache.spark.sql.functions.*; + * df1.join(df2, col("df1Key").equalTo(col("df2Key")), "outer"); * }}} * * @param right Right side of the join. * @param joinExprs Join expression. * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + * @group dfops */ - def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame + def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + } /** * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order. @@ -162,66 +387,94 @@ trait DataFrame extends RDDApi[Row] { * df.sort($"sortcol") * df.sort($"sortcol".asc) * }}} + * @group dfops */ @scala.annotation.varargs - def sort(sortCol: String, sortCols: String*): DataFrame + def sort(sortCol: String, sortCols: String*): DataFrame = { + sort((sortCol +: sortCols).map(apply) :_*) + } /** * Returns a new [[DataFrame]] sorted by the given expressions. For example: * {{{ * df.sort($"col1", $"col2".desc) * }}} + * @group dfops */ @scala.annotation.varargs - def sort(sortExpr: Column, sortExprs: Column*): DataFrame + def sort(sortExprs: Column*): DataFrame = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + Sort(sortOrder, global = true, logicalPlan) + } /** * Returns a new [[DataFrame]] sorted by the given expressions. * This is an alias of the `sort` function. + * @group dfops */ @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DataFrame + def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols :_*) /** * Returns a new [[DataFrame]] sorted by the given expressions. * This is an alias of the `sort` function. + * @group dfops */ @scala.annotation.varargs - def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame + def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs :_*) /** * Selects column based on the column name and return it as a [[Column]]. + * @group dfops */ def apply(colName: String): Column = col(colName) /** * Selects column based on the column name and return it as a [[Column]]. + * @group dfops */ - def col(colName: String): Column + def col(colName: String): Column = colName match { + case "*" => + Column(ResolvedStar(schema.fieldNames.map(resolve))) + case _ => + val expr = resolve(colName) + Column(expr) + } /** - * Selects a set of expressions, wrapped in a Product. - * {{{ - * // The following two are equivalent: - * df.apply(($"colA", $"colB" + 1)) - * df.select($"colA", $"colB" + 1) - * }}} + * Returns a new [[DataFrame]] with an alias set. + * @group dfops */ - def apply(projection: Product): DataFrame + def as(alias: String): DataFrame = Subquery(alias, logicalPlan) /** - * Returns a new [[DataFrame]] with an alias set. + * (Scala-specific) Returns a new [[DataFrame]] with an alias set. + * @group dfops */ - def as(name: String): DataFrame + def as(alias: Symbol): DataFrame = as(alias.name) /** * Selects a set of expressions. * {{{ * df.select($"colA", $"colB" + 1) * }}} + * @group dfops */ @scala.annotation.varargs - def select(cols: Column*): DataFrame + def select(cols: Column*): DataFrame = { + val namedExpressions = cols.map { + case Column(expr: NamedExpression) => expr + case Column(expr: Expression) => Alias(expr, expr.prettyString)() + } + Project(namedExpressions.toSeq, logicalPlan) + } /** * Selects a set of columns. This is a variant of `select` that can only select @@ -232,9 +485,26 @@ trait DataFrame extends RDDApi[Row] { * df.select("colA", "colB") * df.select($"colA", $"colB") * }}} + * @group dfops + */ + @scala.annotation.varargs + def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) :_*) + + /** + * Selects a set of SQL expressions. This is a variant of `select` that accepts + * SQL expressions. + * + * {{{ + * df.selectExpr("colA", "colB as newName", "abs(colC)") + * }}} + * @group dfops */ @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame + def selectExpr(exprs: String*): DataFrame = { + select(exprs.map { expr => + Column(new SqlParser().parseExpression(expr)) + }: _*) + } /** * Filters rows using the given condition. @@ -244,34 +514,36 @@ trait DataFrame extends RDDApi[Row] { * peopleDf.where($"age" > 15) * peopleDf($"age" > 15) * }}} + * @group dfops */ - def filter(condition: Column): DataFrame + def filter(condition: Column): DataFrame = Filter(condition.expr, logicalPlan) /** - * Filters rows using the given condition. This is an alias for `filter`. + * Filters rows using the given SQL expression. * {{{ - * // The following are equivalent: - * peopleDf.filter($"age" > 15) - * peopleDf.where($"age" > 15) - * peopleDf($"age" > 15) + * peopleDf.filter("age > 15") * }}} + * @group dfops */ - def where(condition: Column): DataFrame + def filter(conditionExpr: String): DataFrame = { + filter(Column(new SqlParser().parseExpression(conditionExpr))) + } /** - * Filters rows using the given condition. This is a shorthand meant for Scala. + * Filters rows using the given condition. This is an alias for `filter`. * {{{ * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) * peopleDf($"age" > 15) * }}} + * @group dfops */ - def apply(condition: Column): DataFrame + def where(condition: Column): DataFrame = filter(condition) /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. - * See [[GroupedDataFrame]] for all the available aggregate functions. + * See [[GroupedData]] for all the available aggregate functions. * * {{{ * // Compute the average for all numeric columns grouped by department. @@ -283,13 +555,14 @@ trait DataFrame extends RDDApi[Row] { * "age" -> "max" * )) * }}} + * @group dfops */ @scala.annotation.varargs - def groupBy(cols: Column*): GroupedDataFrame + def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr)) /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. - * See [[GroupedDataFrame]] for all the available aggregate functions. + * See [[GroupedData]] for all the available aggregate functions. * * This is a variant of groupBy that can only group by existing columns using column names * (i.e. cannot construct expressions). @@ -304,9 +577,13 @@ trait DataFrame extends RDDApi[Row] { * "age" -> "max" * )) * }}} + * @group dfops */ @scala.annotation.varargs - def groupBy(col1: String, cols: String*): GroupedDataFrame + def groupBy(col1: String, cols: String*): GroupedData = { + val colNames: Seq[String] = col1 +: cols + new GroupedData(this, colNames.map(colName => resolve(colName))) + } /** * (Scala-specific) Compute aggregates by specifying a map from column name to @@ -320,6 +597,7 @@ trait DataFrame extends RDDApi[Row] { * "expense" -> "sum" * ) * }}} + * @group dfops */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { groupBy().agg(aggExpr, aggExprs :_*) @@ -332,6 +610,7 @@ trait DataFrame extends RDDApi[Row] { * df.agg(Map("age" -> "max", "salary" -> "avg")) * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) * }} + * @group dfops */ def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) @@ -342,6 +621,7 @@ trait DataFrame extends RDDApi[Row] { * df.agg(Map("age" -> "max", "salary" -> "avg")) * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) * }} + * @group dfops */ def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) @@ -352,6 +632,7 @@ trait DataFrame extends RDDApi[Row] { * df.agg(max($"age"), avg($"salary")) * df.groupBy().agg(max($"age"), avg($"salary")) * }} + * @group dfops */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) @@ -359,26 +640,30 @@ trait DataFrame extends RDDApi[Row] { /** * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]]. + * @group dfops */ - def limit(n: Int): DataFrame + def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan) /** * Returns a new [[DataFrame]] containing union of rows in this frame and another frame. * This is equivalent to `UNION ALL` in SQL. + * @group dfops */ - def unionAll(other: DataFrame): DataFrame + def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan) /** * Returns a new [[DataFrame]] containing rows only in both this frame and another frame. * This is equivalent to `INTERSECT` in SQL. + * @group dfops */ - def intersect(other: DataFrame): DataFrame + def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan) /** * Returns a new [[DataFrame]] containing rows in this frame but not in another frame. * This is equivalent to `EXCEPT` in SQL. + * @group dfops */ - def except(other: DataFrame): DataFrame + def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan) /** * Returns a new [[DataFrame]] by sampling a fraction of rows. @@ -386,96 +671,282 @@ trait DataFrame extends RDDApi[Row] { * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. * @param seed Seed for sampling. + * @group dfops */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame + def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { + Sample(fraction, withReplacement, seed, logicalPlan) + } /** * Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed. * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. + * @group dfops */ def sample(withReplacement: Boolean, fraction: Double): DataFrame = { sample(withReplacement, fraction, Utils.random.nextLong) } + /** + * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more + * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of + * the input row are implicitly joined with each row that is output by the function. + * + * The following example uses this function to count the number of books which contain + * a given word: + * + * {{{ + * case class Book(title: String, words: String) + * val df: RDD[Book] + * + * case class Word(word: String) + * val allWords = df.explode('words) { + * case Row(words: String) => words.split(" ").map(Word(_)) + * } + * + * val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title")) + * }}} + * @group dfops + */ + def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val attributes = schema.toAttributes + val rowFunction = + f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])) + val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) + + Generate(generator, join = true, outer = false, None, logicalPlan) + } + + /** + * (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero + * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All + * columns of the input row are implicitly joined with each value that is output by the function. + * + * {{{ + * df.explode("words", "word")(words: String => words.split(" ")) + * }}} + * @group dfops + */ + def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) + : DataFrame = { + val dataType = ScalaReflection.schemaFor[B].dataType + val attributes = AttributeReference(outputColumn, dataType)() :: Nil + def rowFunction(row: Row): TraversableOnce[Row] = { + f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) + } + val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) + + Generate(generator, join = true, outer = false, None, logicalPlan) + } + ///////////////////////////////////////////////////////////////////////////// /** * Returns a new [[DataFrame]] by adding a column. + * @group dfops */ - def addColumn(colName: String, col: Column): DataFrame + def withColumn(colName: String, col: Column): DataFrame = select(Column("*"), col.as(colName)) + + /** + * Returns a new [[DataFrame]] with a column renamed. + * @group dfops + */ + def withColumnRenamed(existingName: String, newName: String): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val colNames = schema.map { field => + val name = field.name + if (resolver(name, existingName)) Column(name).as(newName) else Column(name) + } + select(colNames :_*) + } + + /** + * Computes statistics for numeric columns, including count, mean, stddev, min, and max. + * If no columns are given, this function computes statistics for all numerical columns. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to + * programmatically compute summary statistics, use the `agg` function instead. + * + * {{{ + * df.describe("age", "height").show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // mean 53.3 178.05 + * // stddev 11.6 15.7 + * // min 18.0 163.0 + * // max 92.0 192.0 + * }}} + * + * @group action + */ + @scala.annotation.varargs + def describe(cols: String*): DataFrame = { + + // TODO: Add stddev as an expression, and remove it from here. + def stddevExpr(expr: Expression): Expression = + Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) + + // The list of summary statistics to compute, in the form of expressions. + val statistics = List[(String, Expression => Expression)]( + "count" -> Count, + "mean" -> Average, + "stddev" -> stddevExpr, + "min" -> Min, + "max" -> Max) + + val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList + + val ret: Seq[Row] = if (outputCols.nonEmpty) { + val aggExprs = statistics.flatMap { case (_, colToAgg) => + outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) + } + + val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + + // Pivot the data so each summary is one row + row.grouped(outputCols.size).toSeq.zip(statistics).map { + case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*) + } + } else { + // If there are no output columns, just output a single column that contains the stats. + statistics.map { case (name, _) => Row(name) } + } + + // The first column is string type, and the rest are double type. + val schema = StructType( + StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes + LocalRelation(schema, ret) + } /** * Returns the first `n` rows. + * @group action */ - def head(n: Int): Array[Row] + def head(n: Int): Array[Row] = limit(n).collect() /** * Returns the first row. + * @group action */ - def head(): Row + def head(): Row = head(1).head /** * Returns the first row. Alias for head(). + * @group action */ - override def first(): Row + override def first(): Row = head() /** * Returns a new RDD by applying a function to all rows of this DataFrame. + * @group rdd */ - override def map[R: ClassTag](f: Row => R): RDD[R] + override def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) /** * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], * and then flattening the results. + * @group rdd */ - override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] + override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) /** * Returns a new RDD by applying a function to each partition of this DataFrame. + * @group rdd */ - override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] + override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { + rdd.mapPartitions(f) + } + /** * Applies a function `f` to all rows. + * @group rdd */ - override def foreach(f: Row => Unit): Unit + override def foreach(f: Row => Unit): Unit = rdd.foreach(f) /** * Applies a function f to each partition of this [[DataFrame]]. + * @group rdd */ - override def foreachPartition(f: Iterator[Row] => Unit): Unit + override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) /** * Returns the first `n` rows in the [[DataFrame]]. + * @group action */ - override def take(n: Int): Array[Row] + override def take(n: Int): Array[Row] = head(n) /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. + * @group action */ - override def collect(): Array[Row] + override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() /** * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. + * @group action */ - override def collectAsList(): java.util.List[Row] + override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*) /** * Returns the number of rows in the [[DataFrame]]. + * @group action */ - override def count(): Long + override def count(): Long = groupBy().count().collect().head.getLong(0) /** * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. + * @group rdd */ - override def repartition(numPartitions: Int): DataFrame + override def repartition(numPartitions: Int): DataFrame = { + sqlContext.createDataFrame( + queryExecution.toRdd.map(_.copy()).repartition(numPartitions), + schema, needsConversion = false) + } - override def persist(): this.type + /** + * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. + * @group dfops + */ + override def distinct: DataFrame = Distinct(logicalPlan) - override def persist(newLevel: StorageLevel): this.type + /** + * @group basic + */ + override def persist(): this.type = { + sqlContext.cacheManager.cacheQuery(this) + this + } + + /** + * @group basic + */ + override def cache(): this.type = persist() + + /** + * @group basic + */ + override def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheManager.cacheQuery(this, None, newLevel) + this + } - override def unpersist(blocking: Boolean): this.type + /** + * @group basic + */ + override def unpersist(blocking: Boolean): this.type = { + sqlContext.cacheManager.tryUncacheQuery(this, blocking) + this + } + + /** + * @group basic + */ + override def unpersist(): this.type = unpersist(blocking = false) ///////////////////////////////////////////////////////////////////////////// // I/O @@ -483,113 +954,350 @@ trait DataFrame extends RDDApi[Row] { /** * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. + * @group rdd */ - def rdd: RDD[Row] + def rdd: RDD[Row] = { + // use a local variable to make sure the map closure doesn't capture the whole DataFrame + val schema = this.schema + queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) + } /** * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. + * @group rdd */ def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD() /** * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. + * @group rdd */ def javaRDD: JavaRDD[Row] = toJavaRDD /** - * Registers this RDD as a temporary table using the given name. The lifetime of this temporary - * table is tied to the [[SQLContext]] that was used to create this DataFrame. + * Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this + * temporary table is tied to the [[SQLContext]] that was used to create this DataFrame. * - * @group schema + * @group basic */ - def registerTempTable(tableName: String): Unit + def registerTempTable(tableName: String): Unit = { + sqlContext.registerDataFrameAsTable(this, tableName) + } /** * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. + * @group output */ - def saveAsParquetFile(path: String): Unit + def saveAsParquetFile(path: String): Unit = { + if (sqlContext.conf.parquetUseDataSourceApi) { + save("org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> path)) + } else { + sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + } + } /** * :: Experimental :: - * Creates a table from the the contents of this DataFrame. This will fail if the table already - * exists. + * Creates a table from the the contents of this DataFrame. + * It will use the default data source configured by spark.sql.sources.default. + * This will fail if the table already exists. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + * @group output + */ + @Experimental + def saveAsTable(tableName: String): Unit = { + saveAsTable(tableName, SaveMode.ErrorIfExists) + } + + /** + * :: Experimental :: + * Creates a table from the the contents of this DataFrame, using the default data source + * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + * @group output + */ + @Experimental + def saveAsTable(tableName: String, mode: SaveMode): Unit = { + if (sqlContext.catalog.tableExists(Seq(tableName)) && mode == SaveMode.Append) { + // If table already exists and the save mode is Append, + // we will just call insertInto to append the contents of this DataFrame. + insertInto(tableName, overwrite = false) + } else { + val dataSourceName = sqlContext.conf.defaultDataSourceName + saveAsTable(tableName, dataSourceName, mode) + } + } + + /** + * :: Experimental :: + * Creates a table at the given path from the the contents of this DataFrame + * based on a given data source and a set of options, + * using [[SaveMode.ErrorIfExists]] as the save mode. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + * @group output + */ + @Experimental + def saveAsTable(tableName: String, source: String): Unit = { + saveAsTable(tableName, source, SaveMode.ErrorIfExists) + } + + /** + * :: Experimental :: + * Creates a table at the given path from the the contents of this DataFrame + * based on a given data source, [[SaveMode]] specified by mode, and a set of options. * * Note that this currently only works with DataFrames that are created from a HiveContext as * there is no notion of a persisted catalog in a standard SQL context. Instead you can write * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. + * @group output */ @Experimental - def saveAsTable(tableName: String): Unit + def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { + saveAsTable(tableName, source, mode, Map.empty[String, String]) + } /** * :: Experimental :: - * Creates a table from the the contents of this DataFrame based on a given data source and - * a set of options. This will fail if the table already exists. + * Creates a table at the given path from the the contents of this DataFrame + * based on a given data source, [[SaveMode]] specified by mode, and a set of options. * * Note that this currently only works with DataFrames that are created from a HiveContext as * there is no notion of a persisted catalog in a standard SQL context. Instead you can write * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. + * @group output */ @Experimental def saveAsTable( tableName: String, - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit + source: String, + mode: SaveMode, + options: java.util.Map[String, String]): Unit = { + saveAsTable(tableName, source, mode, options.toMap) + } /** * :: Experimental :: - * Creates a table from the the contents of this DataFrame based on a given data source and - * a set of options. This will fail if the table already exists. + * (Scala-specific) + * Creates a table from the the contents of this DataFrame based on a given data source, + * [[SaveMode]] specified by mode, and a set of options. * * Note that this currently only works with DataFrames that are created from a HiveContext as * there is no notion of a persisted catalog in a standard SQL context. Instead you can write * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. + * @group output */ @Experimental def saveAsTable( tableName: String, - dataSourceName: String, - options: java.util.Map[String, String]): Unit + source: String, + mode: SaveMode, + options: Map[String, String]): Unit = { + val cmd = + CreateTableUsingAsSelect( + tableName, + source, + temporary = false, + mode, + options, + logicalPlan) + + sqlContext.executePlan(cmd).toRdd + } + + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path, + * using the default data source configured by spark.sql.sources.default and + * [[SaveMode.ErrorIfExists]] as the save mode. + * @group output + */ + @Experimental + def save(path: String): Unit = { + save(path, SaveMode.ErrorIfExists) + } + + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, + * using the default data source configured by spark.sql.sources.default. + * @group output + */ + @Experimental + def save(path: String, mode: SaveMode): Unit = { + val dataSourceName = sqlContext.conf.defaultDataSourceName + save(path, dataSourceName, mode) + } + + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path based on the given data source, + * using [[SaveMode.ErrorIfExists]] as the save mode. + * @group output + */ + @Experimental + def save(path: String, source: String): Unit = { + save(source, SaveMode.ErrorIfExists, Map("path" -> path)) + } + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path based on the given data source and + * [[SaveMode]] specified by mode. + * @group output + */ @Experimental - def save(path: String): Unit + def save(path: String, source: String, mode: SaveMode): Unit = { + save(source, mode, Map("path" -> path)) + } + /** + * :: Experimental :: + * Saves the contents of this DataFrame based on the given data source, + * [[SaveMode]] specified by mode, and a set of options. + * @group output + */ @Experimental def save( - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit + source: String, + mode: SaveMode, + options: java.util.Map[String, String]): Unit = { + save(source, mode, options.toMap) + } + /** + * :: Experimental :: + * (Scala-specific) + * Saves the contents of this DataFrame based on the given data source, + * [[SaveMode]] specified by mode, and a set of options + * @group output + */ @Experimental def save( - dataSourceName: String, - options: java.util.Map[String, String]): Unit + source: String, + mode: SaveMode, + options: Map[String, String]): Unit = { + ResolvedDataSource(sqlContext, source, mode, options, this) + } /** * :: Experimental :: * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. + * @group output */ @Experimental - def insertInto(tableName: String, overwrite: Boolean): Unit + def insertInto(tableName: String, overwrite: Boolean): Unit = { + sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), + Map.empty, logicalPlan, overwrite)).toRdd + } /** * :: Experimental :: * Adds the rows from this RDD to the specified table. * Throws an exception if the table already exists. + * @group output */ @Experimental def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) /** * Returns the content of the [[DataFrame]] as a RDD of JSON strings. + * @group rdd */ - def toJSON: RDD[String] + def toJSON: RDD[String] = { + val rowSchema = this.schema + this.mapPartitions { iter => + val writer = new CharArrayWriter() + // create the Generator without separator inserted between 2 records + val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + + new Iterator[String] { + override def hasNext: Boolean = iter.hasNext + override def next(): String = { + JsonRDD.rowToJSON(rowSchema, gen)(iter.next()) + gen.flush() + + val json = writer.toString + if (hasNext) { + writer.reset() + } else { + gen.close() + } + + json + } + } + } + } + + //////////////////////////////////////////////////////////////////////////// + // JDBC Write Support + //////////////////////////////////////////////////////////////////////////// + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. + * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. + * If you pass `true` for `allowExisting`, it will drop any table with the + * given name; if you pass `false`, it will throw if the table already + * exists. + * @group output + */ + def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { + val conn = DriverManager.getConnection(url) + try { + if (allowExisting) { + val sql = s"DROP TABLE IF EXISTS $table" + conn.prepareStatement(sql).executeUpdate() + } + val schema = JDBCWriteDetails.schemaString(this, url) + val sql = s"CREATE TABLE $table ($schema)" + conn.prepareStatement(sql).executeUpdate() + } finally { + conn.close() + } + JDBCWriteDetails.saveTable(this, url, table) + } + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. + * Assumes the table already exists and has a compatible schema. If you + * pass `true` for `overwrite`, it will `TRUNCATE` the table before + * performing the `INSERT`s. + * + * The table must already exist on the database. It must have a schema + * that is compatible with the schema of this RDD; inserting the rows of + * the RDD in order via the simple statement + * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. + * @group output + */ + def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { + if (overwrite) { + val conn = DriverManager.getConnection(url) + try { + val sql = s"TRUNCATE TABLE $table" + conn.prepareStatement(sql).executeUpdate() + } finally { + conn.close() + } + } + JDBCWriteDetails.saveTable(this, url, table) + } //////////////////////////////////////////////////////////////////////////// // for Python API @@ -598,5 +1306,9 @@ trait DataFrame extends RDDApi[Row] { /** * Converts a JavaRDD to a PythonRDD. */ - protected[sql] def javaToPython: JavaRDD[Array[Byte]] + protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala similarity index 65% rename from sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala rename to sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala index ac479b26a7c6a..a3187fe3230fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala @@ -17,17 +17,14 @@ package org.apache.spark.sql -import scala.language.implicitConversions +/** + * A container for a [[DataFrame]], used for implicit conversions. + */ +private[sql] case class DataFrameHolder(df: DataFrame) { -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + // This is declared with parentheses to prevent the Scala compiler from treating + // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + def toDF(): DataFrame = df - -private[sql] class ComputableColumn protected[sql]( - sqlContext: SQLContext, - protected[sql] val plan: LogicalPlan, - protected[sql] val expr: Expression) - extends DataFrameImpl(sqlContext, plan) with Column { - - override def isComputable: Boolean = true + def toDF(colNames: String*): DataFrame = df.toDF(colNames :_*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala deleted file mode 100644 index 0b0623dc1fe75..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ /dev/null @@ -1,370 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import scala.language.implicitConversions -import scala.reflect.ClassTag -import scala.collection.JavaConversions._ - -import com.fasterxml.jackson.core.JsonFactory - -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.SerDeUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} -import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsLogicalPlan} -import org.apache.spark.sql.types.{NumericType, StructType} - - -/** - * Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly. - */ -private[sql] class DataFrameImpl protected[sql]( - override val sqlContext: SQLContext, - val queryExecution: SQLContext#QueryExecution) - extends DataFrame { - - /** - * A constructor that automatically analyzes the logical plan. This reports error eagerly - * as the [[DataFrame]] is constructed. - */ - def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = { - this(sqlContext, { - val qe = sqlContext.executePlan(logicalPlan) - qe.analyzed // This should force analysis and throw errors if there are any - qe - }) - } - - @transient protected[sql] override val logicalPlan: LogicalPlan = queryExecution.logical match { - // For various commands (like DDL) and queries with side effects, we force query optimization to - // happen right away to let these side effects take place eagerly. - case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) - case _ => - queryExecution.logical - } - - /** - * An implicit conversion function internal to this class for us to avoid doing - * "new DataFrameImpl(...)" everywhere. - */ - @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = { - new DataFrameImpl(sqlContext, logicalPlan) - } - - protected[sql] def resolve(colName: String): NamedExpression = { - queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse { - throw new RuntimeException( - s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") - } - } - - protected[sql] def numericColumns: Seq[Expression] = { - schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get - } - } - - override def toDataFrame(colName: String, colNames: String*): DataFrame = { - val newNames = colName +: colNames - require(schema.size == newNames.size, - "The number of columns doesn't match.\n" + - "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" + - "New column names: " + newNames.mkString(", ")) - - val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) => - apply(oldName).as(newName) - } - select(newCols :_*) - } - - override def schema: StructType = queryExecution.analyzed.schema - - override def dtypes: Array[(String, String)] = schema.fields.map { field => - (field.name, field.dataType.toString) - } - - override def columns: Array[String] = schema.fields.map(_.name) - - override def printSchema(): Unit = println(schema.treeString) - - override def join(right: DataFrame): DataFrame = { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None) - } - - override def join(right: DataFrame, joinExprs: Column): DataFrame = { - Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr)) - } - - override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) - } - - override def sort(sortCol: String, sortCols: String*): DataFrame = { - orderBy(apply(sortCol), sortCols.map(apply) :_*) - } - - override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = { - val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - Sort(sortOrder, global = true, logicalPlan) - } - - override def orderBy(sortCol: String, sortCols: String*): DataFrame = { - sort(sortCol, sortCols :_*) - } - - override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = { - sort(sortExpr, sortExprs :_*) - } - - override def col(colName: String): Column = colName match { - case "*" => - Column(ResolvedStar(schema.fieldNames.map(resolve))) - case _ => - val expr = resolve(colName) - Column(sqlContext, Project(Seq(expr), logicalPlan), expr) - } - - override def apply(projection: Product): DataFrame = { - require(projection.productArity >= 1) - select(projection.productIterator.map { - case c: Column => c - case o: Any => Column(Literal(o)) - }.toSeq :_*) - } - - override def as(name: String): DataFrame = Subquery(name, logicalPlan) - - override def select(cols: Column*): DataFrame = { - val exprs = cols.zipWithIndex.map { - case (Column(expr: NamedExpression), _) => - expr - case (Column(expr: Expression), _) => - Alias(expr, expr.toString)() - } - Project(exprs.toSeq, logicalPlan) - } - - override def select(col: String, cols: String*): DataFrame = { - select((col +: cols).map(Column(_)) :_*) - } - - override def filter(condition: Column): DataFrame = { - Filter(condition.expr, logicalPlan) - } - - override def where(condition: Column): DataFrame = { - filter(condition) - } - - override def apply(condition: Column): DataFrame = { - filter(condition) - } - - override def groupBy(cols: Column*): GroupedDataFrame = { - new GroupedDataFrame(this, cols.map(_.expr)) - } - - override def groupBy(col1: String, cols: String*): GroupedDataFrame = { - val colNames: Seq[String] = col1 +: cols - new GroupedDataFrame(this, colNames.map(colName => resolve(colName))) - } - - override def limit(n: Int): DataFrame = { - Limit(Literal(n), logicalPlan) - } - - override def unionAll(other: DataFrame): DataFrame = { - Union(logicalPlan, other.logicalPlan) - } - - override def intersect(other: DataFrame): DataFrame = { - Intersect(logicalPlan, other.logicalPlan) - } - - override def except(other: DataFrame): DataFrame = { - Except(logicalPlan, other.logicalPlan) - } - - override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { - Sample(fraction, withReplacement, seed, logicalPlan) - } - - ///////////////////////////////////////////////////////////////////////////// - - override def addColumn(colName: String, col: Column): DataFrame = { - select(Column("*"), col.as(colName)) - } - - override def head(n: Int): Array[Row] = limit(n).collect() - - override def head(): Row = head(1).head - - override def first(): Row = head() - - override def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) - - override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) - - override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { - rdd.mapPartitions(f) - } - - override def foreach(f: Row => Unit): Unit = rdd.foreach(f) - - override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) - - override def take(n: Int): Array[Row] = head(n) - - override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() - - override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*) - - override def count(): Long = groupBy().count().rdd.collect().head.getLong(0) - - override def repartition(numPartitions: Int): DataFrame = { - sqlContext.applySchema(rdd.repartition(numPartitions), schema) - } - - override def persist(): this.type = { - sqlContext.cacheManager.cacheQuery(this) - this - } - - override def persist(newLevel: StorageLevel): this.type = { - sqlContext.cacheManager.cacheQuery(this, None, newLevel) - this - } - - override def unpersist(blocking: Boolean): this.type = { - sqlContext.cacheManager.tryUncacheQuery(this, blocking) - this - } - - ///////////////////////////////////////////////////////////////////////////// - // I/O - ///////////////////////////////////////////////////////////////////////////// - - override def rdd: RDD[Row] = { - val schema = this.schema - queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) - } - - override def registerTempTable(tableName: String): Unit = { - sqlContext.registerRDDAsTable(this, tableName) - } - - override def saveAsParquetFile(path: String): Unit = { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd - } - - override def saveAsTable(tableName: String): Unit = { - val dataSourceName = sqlContext.conf.defaultDataSourceName - val cmd = - CreateTableUsingAsLogicalPlan( - tableName, - dataSourceName, - temporary = false, - Map.empty, - allowExisting = false, - logicalPlan) - - sqlContext.executePlan(cmd).toRdd - } - - override def saveAsTable( - tableName: String, - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit = { - val cmd = - CreateTableUsingAsLogicalPlan( - tableName, - dataSourceName, - temporary = false, - (option +: options).toMap, - allowExisting = false, - logicalPlan) - - sqlContext.executePlan(cmd).toRdd - } - - override def saveAsTable( - tableName: String, - dataSourceName: String, - options: java.util.Map[String, String]): Unit = { - val opts = options.toSeq - saveAsTable(tableName, dataSourceName, opts.head, opts.tail:_*) - } - - override def save(path: String): Unit = { - val dataSourceName = sqlContext.conf.defaultDataSourceName - save(dataSourceName, ("path" -> path)) - } - - override def save( - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit = { - ResolvedDataSource(sqlContext, dataSourceName, (option +: options).toMap, this) - } - - override def save( - dataSourceName: String, - options: java.util.Map[String, String]): Unit = { - val opts = options.toSeq - save(dataSourceName, opts.head, opts.tail:_*) - } - - override def insertInto(tableName: String, overwrite: Boolean): Unit = { - sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite)).toRdd - } - - override def toJSON: RDD[String] = { - val rowSchema = this.schema - this.mapPartitions { iter => - val jsonFactory = new JsonFactory() - iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) - } - } - - //////////////////////////////////////////////////////////////////////////// - // for Python API - //////////////////////////////////////////////////////////////////////////// - protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala new file mode 100644 index 0000000000000..bf3c3fe876873 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -0,0 +1,231 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.{lang => jl} + +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + + +/** + * :: Experimental :: + * Functionality for working with missing data in [[DataFrame]]s. + */ +@Experimental +final class DataFrameNaFunctions private[sql](df: DataFrame) { + + /** + * Returns a new [[DataFrame]] that drops rows containing any null values. + */ + def drop(): DataFrame = drop("any", df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing null values. + * + * If `how` is "any", then drop rows containing any null values. + * If `how` is "all", then drop rows only if every column is null for that row. + */ + def drop(how: String): DataFrame = drop(how, df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing any null values + * in the specified columns. + */ + def drop(cols: Array[String]): DataFrame = drop(cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame ]] that drops rows containing any null values + * in the specified columns. + */ + def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols) + + /** + * Returns a new [[DataFrame]] that drops rows containing null values + * in the specified columns. + * + * If `how` is "any", then drop rows containing any null values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null for that row. + */ + def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null values + * in the specified columns. + * + * If `how` is "any", then drop rows containing any null values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null for that row. + */ + def drop(how: String, cols: Seq[String]): DataFrame = { + how.toLowerCase match { + case "any" => drop(cols.size, cols) + case "all" => drop(1, cols) + case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") + } + } + + /** + * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null values. + */ + def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null + * values in the specified columns. + */ + def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing less than + * `minNonNulls` non-null values in the specified columns. + */ + def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { + // Filtering condition -- only keep the row if it has at least `minNonNulls` non-null values. + val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) + df.filter(Column(predicate)) + } + + /** + * Returns a new [[DataFrame]] that replaces null values in numeric columns with `value`. + */ + def fill(value: Double): DataFrame = fill(value, df.columns) + + /** + * Returns a new [[DataFrame ]] that replaces null values in string columns with `value`. + */ + def fill(value: String): DataFrame = fill(value, df.columns) + + /** + * Returns a new [[DataFrame]] that replaces null values in specified numeric columns. + * If a specified column is not a numeric column, it is ignored. + */ + def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified + * numeric columns. If a specified column is not a numeric column, it is ignored. + */ + def fill(value: Double, cols: Seq[String]): DataFrame = { + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + // Only fill if the column is part of the cols list. + if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) { + fillCol[Double](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + /** + * Returns a new [[DataFrame]] that replaces null values in specified string columns. + * If a specified column is not a string column, it is ignored. + */ + def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in + * specified string columns. If a specified column is not a string column, it is ignored. + */ + def fill(value: String, cols: Seq[String]): DataFrame = { + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + // Only fill if the column is part of the cols list. + if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) { + fillCol[String](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + /** + * Returns a new [[DataFrame]] that replaces null values. + * + * The key of the map is the column name, and the value of the map is the replacement value. + * The value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`. + * + * For example, the following replaces null values in column "A" with string "unknown", and + * null values in column "B" with numeric value 1.0. + * {{{ + * import com.google.common.collect.ImmutableMap; + * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0)); + * }}} + */ + def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values. + * + * The key of the map is the column name, and the value of the map is the replacement value. + * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`. + * + * For example, the following replaces null values in column "A" with string "unknown", and + * null values in column "B" with numeric value 1.0. + * {{{ + * df.na.fill(Map( + * "A" -> "unknown", + * "B" -> 1.0 + * )) + * }}} + */ + def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + + private def fill0(values: Seq[(String, Any)]): DataFrame = { + // Error handling + values.foreach { case (colName, replaceValue) => + // Check column name exists + df.resolve(colName) + + // Check data type + replaceValue match { + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String => + // This is good + case _ => throw new IllegalArgumentException( + s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).") + } + } + + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) => + v match { + case v: jl.Float => fillCol[Double](f, v.toDouble) + case v: jl.Double => fillCol[Double](f, v) + case v: jl.Long => fillCol[Double](f, v.toDouble) + case v: jl.Integer => fillCol[Double](f, v.toDouble) + case v: String => fillCol[String](f, v) + } + }.getOrElse(df.col(f.name)) + } + df.select(projections : _*) + } + + /** + * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. + */ + private def fillCol[T](col: StructField, replacement: T): Column = { + coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala deleted file mode 100644 index 71365c776d559..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala +++ /dev/null @@ -1,587 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.util.{List => JList} - -import scala.language.implicitConversions -import scala.reflect.runtime.universe.{TypeTag, typeTag} -import scala.collection.JavaConversions._ - -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - - -/** - * Domain specific functions available for [[DataFrame]]. - */ -object Dsl { - - /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - - /** Converts $"col name" into an [[Column]]. */ - implicit class StringToColumn(val sc: StringContext) extends AnyVal { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args :_*)) - } - } - - private[this] implicit def toColumn(expr: Expression): Column = Column(expr) - - /** - * Returns a [[Column]] based on the given column name. - */ - def col(colName: String): Column = Column(colName) - - /** - * Returns a [[Column]] based on the given column name. Alias of [[col]]. - */ - def column(colName: String): Column = Column(colName) - - /** - * Creates a [[Column]] of literal value. - * - * The passed in object is returned directly if it is already a [[Column]]. - * If the object is a Scala Symbol, it is converted into a [[Column]] also. - * Otherwise, a new [[Column]] is created to represent the literal value. - */ - def lit(literal: Any): Column = { - literal match { - case c: Column => return c - case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name) - case _ => // continue - } - - val literalExpr = literal match { - case v: Boolean => Literal(v, BooleanType) - case v: Byte => Literal(v, ByteType) - case v: Short => Literal(v, ShortType) - case v: Int => Literal(v, IntegerType) - case v: Long => Literal(v, LongType) - case v: Float => Literal(v, FloatType) - case v: Double => Literal(v, DoubleType) - case v: String => Literal(v, StringType) - case v: BigDecimal => Literal(Decimal(v), DecimalType.Unlimited) - case v: java.math.BigDecimal => Literal(Decimal(v), DecimalType.Unlimited) - case v: Decimal => Literal(v, DecimalType.Unlimited) - case v: java.sql.Timestamp => Literal(v, TimestampType) - case v: java.sql.Date => Literal(v, DateType) - case v: Array[Byte] => Literal(v, BinaryType) - case null => Literal(null, NullType) - case _ => - throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal) - } - Column(literalExpr) - } - - ////////////////////////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////////////////// - - /** Aggregate function: returns the sum of all values in the expression. */ - def sum(e: Column): Column = Sum(e.expr) - - /** Aggregate function: returns the sum of distinct values in the expression. */ - def sumDistinct(e: Column): Column = SumDistinct(e.expr) - - /** Aggregate function: returns the number of items in a group. */ - def count(e: Column): Column = Count(e.expr) - - /** Aggregate function: returns the number of distinct items in a group. */ - @scala.annotation.varargs - def countDistinct(expr: Column, exprs: Column*): Column = - CountDistinct((expr +: exprs).map(_.expr)) - - /** Aggregate function: returns the approximate number of distinct items in a group. */ - def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) - - /** Aggregate function: returns the approximate number of distinct items in a group. */ - def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) - - /** Aggregate function: returns the average of the values in a group. */ - def avg(e: Column): Column = Average(e.expr) - - /** Aggregate function: returns the first value in a group. */ - def first(e: Column): Column = First(e.expr) - - /** Aggregate function: returns the last value in a group. */ - def last(e: Column): Column = Last(e.expr) - - /** Aggregate function: returns the minimum value of the expression in a group. */ - def min(e: Column): Column = Min(e.expr) - - /** Aggregate function: returns the maximum value of the expression in a group. */ - def max(e: Column): Column = Max(e.expr) - - ////////////////////////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////////////////// - - /** - * Unary minus, i.e. negate the expression. - * {{{ - * // Select the amount column and negates all values. - * // Scala: - * df.select( -df("amount") ) - * - * // Java: - * df.select( negate(df.col("amount")) ); - * }}} - */ - def negate(e: Column): Column = -e - - /** - * Inversion of boolean expression, i.e. NOT. - * {{ - * // Scala: select rows that are not active (isActive === false) - * df.filter( !df("isActive") ) - * - * // Java: - * df.filter( not(df.col("isActive")) ); - * }} - */ - def not(e: Column): Column = !e - - /** Converts a string expression to upper case. */ - def upper(e: Column): Column = Upper(e.expr) - - /** Converts a string exprsesion to lower case. */ - def lower(e: Column): Column = Lower(e.expr) - - /** Computes the square root of the specified float value. */ - def sqrt(e: Column): Column = Sqrt(e.expr) - - /** Computes the absolutle value. */ - def abs(e: Column): Column = Abs(e.expr) - - /** - * This is a private API for Python - * TODO: move this to a private package - */ - def toColumns(cols: JList[Column]): Seq[Column] = { - cols.toList.toSeq - } - - ////////////////////////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////////////////// - - // scalastyle:off - - /* Use the following code to generate: - (0 to 22).map { x => - val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") - val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") - println(s""" - /** - * Call a Scala function of ${x} arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[$typeTags](f: Function$x[$types]${if (args.length > 0) ", " + args else ""}): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq($argsInUdf)) - }""") - } - - (0 to 22).map { x => - val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") - val fTypes = Seq.fill(x + 1)("_").mkString(", ") - val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") - println(s""" - /** - * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { - ScalaUdf(f, returnType, Seq($argsInUdf)) - }""") - } - } - */ - /** - * Call a Scala function of 0 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag](f: Function0[RT]): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq()) - } - - /** - * Call a Scala function of 1 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT], arg1: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr)) - } - - /** - * Call a Scala function of 2 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT], arg1: Column, arg2: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr)) - } - - /** - * Call a Scala function of 3 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT], arg1: Column, arg2: Column, arg3: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr)) - } - - /** - * Call a Scala function of 4 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) - } - - /** - * Call a Scala function of 5 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) - } - - /** - * Call a Scala function of 6 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) - } - - /** - * Call a Scala function of 7 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) - } - - /** - * Call a Scala function of 8 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) - } - - /** - * Call a Scala function of 9 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) - } - - /** - * Call a Scala function of 10 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) - } - - /** - * Call a Scala function of 11 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](f: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr)) - } - - /** - * Call a Scala function of 12 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](f: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr)) - } - - /** - * Call a Scala function of 13 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](f: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr)) - } - - /** - * Call a Scala function of 14 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](f: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr)) - } - - /** - * Call a Scala function of 15 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](f: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr)) - } - - /** - * Call a Scala function of 16 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](f: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr)) - } - - /** - * Call a Scala function of 17 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](f: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr)) - } - - /** - * Call a Scala function of 18 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](f: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr)) - } - - /** - * Call a Scala function of 19 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](f: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr)) - } - - /** - * Call a Scala function of 20 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](f: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr)) - } - - /** - * Call a Scala function of 21 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](f: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr)) - } - - /** - * Call a Scala function of 22 arguments as user-defined function (UDF), and automatically - * infer the data types based on the function's signature. - */ - def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](f: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr)) - } - - ////////////////////////////////////////////////////////////////////////////////////////////////// - - /** - * Call a Scala function of 0 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function0[_], returnType: DataType): Column = { - ScalaUdf(f, returnType, Seq()) - } - - /** - * Call a Scala function of 1 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr)) - } - - /** - * Call a Scala function of 2 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) - } - - /** - * Call a Scala function of 3 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) - } - - /** - * Call a Scala function of 4 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) - } - - /** - * Call a Scala function of 5 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) - } - - /** - * Call a Scala function of 6 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) - } - - /** - * Call a Scala function of 7 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) - } - - /** - * Call a Scala function of 8 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) - } - - /** - * Call a Scala function of 9 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) - } - - /** - * Call a Scala function of 10 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) - } - - /** - * Call a Scala function of 11 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr)) - } - - /** - * Call a Scala function of 12 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr)) - } - - /** - * Call a Scala function of 13 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr)) - } - - /** - * Call a Scala function of 14 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr)) - } - - /** - * Call a Scala function of 15 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr)) - } - - /** - * Call a Scala function of 16 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr)) - } - - /** - * Call a Scala function of 17 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr)) - } - - /** - * Call a Scala function of 18 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr)) - } - - /** - * Call a Scala function of 19 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr)) - } - - /** - * Call a Scala function of 20 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr)) - } - - /** - * Call a Scala function of 21 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr)) - } - - /** - * Call a Scala function of 22 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - */ - def callUDF(f: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr)) - } - - // scalastyle:on -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index f0e6a8f332188..d5d7e35a6b35d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -20,8 +20,13 @@ package org.apache.spark.sql import org.apache.spark.annotation.Experimental /** + * :: Experimental :: * Holder for experimental methods for the bravest. We make NO guarantee about the stability * regarding binary compatibility and source compatibility of methods here. + * + * {{{ + * sqlContext.experimental.extraStrategies += ... + * }}} */ @Experimental class ExperimentalMethods protected[sql](sqlContext: SQLContext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala similarity index 65% rename from sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala rename to sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 7963cb03126ba..32c6562ef39d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -20,39 +20,66 @@ package org.apache.spark.sql import scala.language.implicitConversions import scala.collection.JavaConversions._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.types.NumericType + /** + * :: Experimental :: * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. */ -class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) { +@Experimental +class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { - private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = { + private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { val namedGroupingExprs = groupingExprs.map { case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.toString)() + case expr: Expression => Alias(expr, expr.prettyString)() } DataFrame( df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) } - private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = { - df.numericColumns.map { c => + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) + : Seq[NamedExpression] = { + + val columnExprs = if (colNames.isEmpty) { + // No columns specified. Use all numeric columns. + df.numericColumns + } else { + // Make sure all specified columns are numeric + colNames.map { colName => + val namedExpr = df.resolve(colName) + if (!namedExpr.dataType.isInstanceOf[NumericType]) { + throw new AnalysisException( + s""""$colName" is not a numeric column. """ + + "Aggregation function can only be performed on a numeric column.") + } + namedExpr + } + } + columnExprs.map { c => val a = f(c) - Alias(a, a.toString)() + Alias(a, a.prettyString)() } } - + private[this] def strToExpr(expr: String): (Expression => Expression) = { expr.toLowerCase match { case "avg" | "average" | "mean" => Average case "max" => Max case "min" => Min case "sum" => Sum - case "count" | "size" => Count + case "count" | "size" => + // Turn count(*) into count(1) + (inputExpr: Expression) => inputExpr match { + case s: Star => Count(Literal(1)) + case _ => Count(inputExpr) + } } } @@ -89,7 +116,7 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr def agg(exprs: Map[String, String]): DataFrame = { exprs.map { case (colName, expr) => val a = strToExpr(expr)(df(colName).expr) - Alias(a, a.toString)() + Alias(a, a.prettyString)() }.toSeq } @@ -101,10 +128,7 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr * {{{ * // Selects the age of the oldest employee and the aggregate expense for each department * import com.google.common.collect.ImmutableMap; - * df.groupBy("department").agg(ImmutableMap.builder() - * .put("age", "max") - * .put("expense", "sum") - * .build()); + * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); * }}} */ def agg(exprs: java.util.Map[String, String]): DataFrame = { @@ -115,17 +139,17 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this * class, the resulting [[DataFrame]] won't automatically include the grouping columns. * - * The available aggregate methods are defined in [[org.apache.spark.sql.Dsl]]. + * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. * * {{{ * // Selects the age of the oldest employee and the aggregate expense for each department * * // Scala: - * import org.apache.spark.sql.dsl._ + * import org.apache.spark.sql.functions._ * df.groupBy("department").agg($"department", max($"age"), sum($"expense")) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.groupBy("department").agg(col("department"), max(col("age")), sum(col("expense"))); * }}} */ @@ -133,7 +157,7 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr def agg(expr: Column, exprs: Column*): DataFrame = { val aggExprs = (expr +: exprs).map(_.expr).map { case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.toString)() + case expr: Expression => Alias(expr, expr.prettyString)() } DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) } @@ -142,35 +166,55 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr * Count the number of rows for each group. * The resulting [[DataFrame]] will also contain the grouping columns. */ - def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")()) + def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")()) /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the average values for them. */ - def mean(): DataFrame = aggregateNumericColumns(Average) - + @scala.annotation.varargs + def mean(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Average) + } + /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the max values for them. */ - def max(): DataFrame = aggregateNumericColumns(Max) + @scala.annotation.varargs + def max(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Max) + } /** * Compute the mean value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the mean values for them. */ - def avg(): DataFrame = aggregateNumericColumns(Average) + @scala.annotation.varargs + def avg(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Average) + } /** * Compute the min value for each numeric column for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the min values for them. */ - def min(): DataFrame = aggregateNumericColumns(Min) + @scala.annotation.varargs + def min(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Min) + } /** * Compute the sum for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. */ - def sum(): DataFrame = aggregateNumericColumns(Sum) + @scala.annotation.varargs + def sum(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Sum) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala deleted file mode 100644 index ba5c7355b4b70..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala +++ /dev/null @@ -1,174 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import scala.reflect.ClassTag - -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.types.StructType - - -private[sql] class IncomputableColumn(protected[sql] val expr: Expression) extends Column { - - def this(name: String) = this(name match { - case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) - case _ => UnresolvedAttribute(name) - }) - - private def err[T](): T = { - throw new UnsupportedOperationException("Cannot run this method on an UncomputableColumn") - } - - override def isComputable: Boolean = false - - override val sqlContext: SQLContext = null - - override def queryExecution = err() - - protected[sql] override def logicalPlan: LogicalPlan = err() - - override def toDataFrame(colName: String, colNames: String*): DataFrame = err() - - override def schema: StructType = err() - - override def dtypes: Array[(String, String)] = err() - - override def columns: Array[String] = err() - - override def printSchema(): Unit = err() - - override def join(right: DataFrame): DataFrame = err() - - override def join(right: DataFrame, joinExprs: Column): DataFrame = err() - - override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = err() - - override def sort(sortCol: String, sortCols: String*): DataFrame = err() - - override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = err() - - override def orderBy(sortCol: String, sortCols: String*): DataFrame = err() - - override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = err() - - override def col(colName: String): Column = err() - - override def apply(projection: Product): DataFrame = err() - - override def select(cols: Column*): DataFrame = err() - - override def select(col: String, cols: String*): DataFrame = err() - - override def filter(condition: Column): DataFrame = err() - - override def where(condition: Column): DataFrame = err() - - override def apply(condition: Column): DataFrame = err() - - override def groupBy(cols: Column*): GroupedDataFrame = err() - - override def groupBy(col1: String, cols: String*): GroupedDataFrame = err() - - override def limit(n: Int): DataFrame = err() - - override def unionAll(other: DataFrame): DataFrame = err() - - override def intersect(other: DataFrame): DataFrame = err() - - override def except(other: DataFrame): DataFrame = err() - - override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err() - - ///////////////////////////////////////////////////////////////////////////// - - override def addColumn(colName: String, col: Column): DataFrame = err() - - override def head(n: Int): Array[Row] = err() - - override def head(): Row = err() - - override def first(): Row = err() - - override def map[R: ClassTag](f: Row => R): RDD[R] = err() - - override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = err() - - override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = err() - - override def foreach(f: Row => Unit): Unit = err() - - override def foreachPartition(f: Iterator[Row] => Unit): Unit = err() - - override def take(n: Int): Array[Row] = err() - - override def collect(): Array[Row] = err() - - override def collectAsList(): java.util.List[Row] = err() - - override def count(): Long = err() - - override def repartition(numPartitions: Int): DataFrame = err() - - override def persist(): this.type = err() - - override def persist(newLevel: StorageLevel): this.type = err() - - override def unpersist(blocking: Boolean): this.type = err() - - override def rdd: RDD[Row] = err() - - override def registerTempTable(tableName: String): Unit = err() - - override def saveAsParquetFile(path: String): Unit = err() - - override def saveAsTable(tableName: String): Unit = err() - - override def saveAsTable( - tableName: String, - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit = err() - - override def saveAsTable( - tableName: String, - dataSourceName: String, - options: java.util.Map[String, String]): Unit = err() - - override def save(path: String): Unit = err() - - override def save( - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit = err() - - override def save( - dataSourceName: String, - options: java.util.Map[String, String]): Unit = err() - - override def insertInto(tableName: String, overwrite: Boolean): Unit = err() - - override def toJSON: RDD[String] = err() - - protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = err() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala index 38e6382f171d5..ba4373f0124b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala @@ -29,13 +29,13 @@ import org.apache.spark.storage.StorageLevel */ private[sql] trait RDDApi[T] { - def cache(): this.type = persist() + def cache(): this.type def persist(): this.type def persist(newLevel: StorageLevel): this.type - def unpersist(): this.type = unpersist(blocking = false) + def unpersist(): this.type def unpersist(blocking: Boolean): this.type @@ -60,4 +60,6 @@ private[sql] trait RDDApi[T] { def first(): T def repartition(numPartitions: Int): DataFrame + + def distinct: DataFrame } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 561a91d2d60ee..4815620c6fe57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -33,9 +33,11 @@ private[spark] object SQLConf { val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" + val PARQUET_INT96_AS_TIMESTAMP = "spark.sql.parquet.int96AsTimestamp" val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" + val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" val BROADCAST_TIMEOUT = "spark.sql.broadcastTimeout" @@ -48,7 +50,16 @@ private[spark] object SQLConf { val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" // This is used to set the default data source - val DEFAULT_DATA_SOURCE_NAME = "spark.sql.default.datasource" + val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default" + // This is used to control the when we will split a schema's JSON string to multiple pieces + // in order to fit the JSON string in metastore's table property (by default, the value has + // a length restriction of 4000 characters). We will split the JSON string of a schema + // to its length exceeds the threshold. + val SCHEMA_STRING_LENGTH_THRESHOLD = "spark.sql.sources.schemaStringLengthThreshold" + + // Whether to perform eager analysis when constructing a dataframe. + // Set to false when debugging requires the ability to look at invalid query plans. + val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -104,6 +115,10 @@ private[sql] class SQLConf extends Serializable { private[spark] def parquetFilterPushDown = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + /** When true uses Parquet implementation based on data source API */ + private[spark] def parquetUseDataSourceApi = + getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean @@ -143,6 +158,12 @@ private[sql] class SQLConf extends Serializable { private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + /** + * When set to true, we always treat INT96Values in Parquet files as timestamp. + */ + private[spark] def isParquetINT96AsTimestamp: Boolean = + getConf(PARQUET_INT96_AS_TIMESTAMP, "true").toBoolean + /** * When set to true, partition pruning for in-memory columnar tables is enabled. */ @@ -161,6 +182,14 @@ private[sql] class SQLConf extends Serializable { private[spark] def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.parquet") + // Do not use a value larger than 4000 as the default value of this property. + // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. + private[spark] def schemaStringLengthThreshold: Int = + getConf(SCHEMA_STRING_LENGTH_THRESHOLD, "4000").toInt + + private[spark] def dataFrameEagerAnalysis: Boolean = + getConf(DATAFRAME_EAGER_ANALYSIS, "true").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a741d0031d155..b7f72e716c7f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -20,37 +20,41 @@ package org.apache.spark.sql import java.beans.Introspector import java.util.Properties -import scala.collection.immutable import scala.collection.JavaConversions._ +import scala.collection.immutable import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.{SparkContext, Partition} -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.json._ +import org.apache.spark.sql.catalyst.{ScalaReflection, expressions} +import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} +import org.apache.spark.sql.json._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +import org.apache.spark.{Partition, SparkContext} /** - * :: AlphaComponent :: - * The entry point for running relational queries using Spark. Allows the creation of [[DataFrame]] - * objects and the execution of SQL queries. + * The entry point for working with structured data (rows and columns) in Spark. Allows the + * creation of [[DataFrame]] objects as well as the execution of SQL queries. * - * @groupname userf Spark SQL Functions + * @groupname basic Basic Operations + * @groupname ddl_ops Persistent Catalog DDL + * @groupname cachemgmt Cached Table Management + * @groupname genericdata Generic Data Sources + * @groupname specificdata Specific Data Sources + * @groupname config Configuration + * @groupname dataframes Custom DataFrame Creation * @groupname Ungrouped Support functions for language integrated queries. */ -@AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) extends org.apache.spark.Logging with Serializable { @@ -62,24 +66,40 @@ class SQLContext(@transient val sparkContext: SparkContext) // Note that this is a lazy val so we can override the default value in subclasses. protected[sql] lazy val conf: SQLConf = new SQLConf - /** Set Spark SQL configuration properties. */ + /** + * Set Spark SQL configuration properties. + * + * @group config + */ def setConf(props: Properties): Unit = conf.setConf(props) - /** Set the given Spark SQL configuration property. */ + /** + * Set the given Spark SQL configuration property. + * + * @group config + */ def setConf(key: String, value: String): Unit = conf.setConf(key, value) - /** Return the value of Spark SQL configuration property for the given key. */ + /** + * Return the value of Spark SQL configuration property for the given key. + * + * @group config + */ def getConf(key: String): String = conf.getConf(key) /** * Return the value of Spark SQL configuration property for the given key. If the key is not set * yet, return `defaultValue`. + * + * @group config */ def getConf(key: String, defaultValue: String): String = conf.getConf(key, defaultValue) /** * Return all the configuration properties that have been set (i.e. not the default). * This creates a new copy of the config properties in the form of a Map. + * + * @group config */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs @@ -87,17 +107,26 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true) @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry + protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true) @transient protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, functionRegistry, caseSensitive = true) + new Analyzer(catalog, functionRegistry, caseSensitive = true) { + override val extendedResolutionRules = + ExtractPythonUdfs :: + sources.PreInsertCastAndRename :: + Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) + } @transient protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer @transient - protected[sql] val ddlParser = new DDLParser + protected[sql] val ddlParser = new DDLParser(sqlParser.apply(_)) @transient protected[sql] val sqlParser = { @@ -118,14 +147,30 @@ class SQLContext(@transient val sparkContext: SparkContext) case _ => } + @transient protected[sql] val cacheManager = new CacheManager(this) /** + * :: Experimental :: * A collection of methods that are considered experimental, but can be used to hook into - * the query planner for advanced functionalities. + * the query planner for advanced functionality. + * + * @group basic */ + @Experimental + @transient val experimental: ExperimentalMethods = new ExperimentalMethods(this) + /** + * :: Experimental :: + * Returns a [[DataFrame]] with no rows or columns. + * + * @group basic + */ + @Experimental + @transient + lazy val emptyDataFrame: DataFrame = createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil)) + /** * A collection of methods for registering user-defined functions (UDF). * @@ -151,33 +196,151 @@ class SQLContext(@transient val sparkContext: SparkContext) * (Integer arg1, String arg2) -> arg2 + arg1), * DataTypes.StringType); * }}} + * + * @group basic */ + @transient val udf: UDFRegistration = new UDFRegistration(this) - /** Returns true if the table is currently cached in-memory. */ + /** + * Returns true if the table is currently cached in-memory. + * @group cachemgmt + */ def isCached(tableName: String): Boolean = cacheManager.isCached(tableName) - /** Caches the specified table in-memory. */ + /** + * Caches the specified table in-memory. + * @group cachemgmt + */ def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName) - /** Removes the specified table from the in-memory cache. */ + /** + * Removes the specified table from the in-memory cache. + * @group cachemgmt + */ def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName) /** + * Removes all cached tables from the in-memory cache. + */ + def clearCache(): Unit = cacheManager.clearCache() + + // scalastyle:off + // Disable style checker so "implicits" object can start with lowercase i + /** + * :: Experimental :: + * (Scala-specific) Implicit methods available in Scala for converting + * common Scala objects into [[DataFrame]]s. + * + * {{{ + * val sqlContext = new SQLContext(sc) + * import sqlContext.implicits._ + * }}} + * + * @group basic + */ + @Experimental + object implicits extends Serializable { + // scalastyle:on + + /** Converts $"col name" into an [[Column]]. */ + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args :_*)) + } + } + + /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + /** Creates a DataFrame from an RDD of case classes or tuples. */ + implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { + DataFrameHolder(self.createDataFrame(rdd)) + } + + /** Creates a DataFrame from a local Seq of Product. */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = + { + DataFrameHolder(self.createDataFrame(data)) + } + + // Do NOT add more implicit conversions. They are likely to break source compatibility by + // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous + // because of [[DoubleRDDFunctions]]. + + /** Creates a single column DataFrame from an RDD[Int]. */ + implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { + val dataType = IntegerType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setInt(0, v) + row: Row + } + } + DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** Creates a single column DataFrame from an RDD[Long]. */ + implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { + val dataType = LongType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setLong(0, v) + row: Row + } + } + DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** Creates a single column DataFrame from an RDD[String]. */ + implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { + val dataType = StringType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setString(0, v) + row: Row + } + } + DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + } + + /** + * :: Experimental :: * Creates a DataFrame from an RDD of case classes. * - * @group userf + * @group dataframes */ - implicit def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = { + @Experimental + def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { SparkPlan.currentContext.set(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema) - DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self)) + DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self)) + } + + /** + * :: Experimental :: + * Creates a DataFrame from a local Seq of Product. + * + * @group dataframes + */ + @Experimental + def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { + SparkPlan.currentContext.set(self) + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val attributeSeq = schema.toAttributes + DataFrame(self, LocalRelation.fromProduct(attributeSeq, data)) } /** * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]]. + * + * @group dataframes */ def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = { DataFrame(this, LogicalRelation(baseRelation)) @@ -185,12 +348,13 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * :: DeveloperApi :: - * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s using the given schema. * It is important to make sure that the structure of every [[Row]] of the provided RDD matches * the provided schema. Otherwise, there will be runtime exception. * Example: * {{{ * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ * val sqlContext = new org.apache.spark.sql.SQLContext(sc) * * val schema = @@ -201,7 +365,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * val people = * sc.textFile("examples/src/main/resources/people.txt").map( * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) - * val dataFrame = sqlContext. applySchema(people, schema) + * val dataFrame = sqlContext.createDataFrame(people, schema) * dataFrame.printSchema * // root * // |-- name: string (nullable = false) @@ -211,23 +375,65 @@ class SQLContext(@transient val sparkContext: SparkContext) * sqlContext.sql("select name from people").collect.foreach(println) * }}} * - * @group userf + * @group dataframes */ @DeveloperApi - def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { + def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema, needsConversion = true) + } + + /** + * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be + * converted to Catalyst rows. + */ + private[sql] + def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. - val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) + val catalystRows = if (needsConversion) { + rowRDD.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]) + } else { + rowRDD + } + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) DataFrame(this, logicalPlan) } + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s using the given schema. + * It is important to make sure that the structure of every [[Row]] of the provided RDD matches + * the provided schema. Otherwise, there will be runtime exception. + * + * @group dataframes + */ + @DeveloperApi + def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD.rdd, schema) + } + + /** + * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s by applying + * a seq of names of columns to this RDD, the data type for each column will + * be inferred by the first row. + * + * @param rowRDD an JavaRDD of Row + * @param columns names for each column + * @return DataFrame + * @group dataframes + */ + def createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]): DataFrame = { + createDataFrame(rowRDD.rdd, columns.toSeq) + } + /** * Applies a schema to an RDD of Java Beans. * * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. + * @group dataframes */ - def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = { val attributeSeq = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => @@ -253,24 +459,96 @@ class SQLContext(@transient val sparkContext: SparkContext) * * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. + * @group dataframes */ + def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd.rdd, beanClass) + } + + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * It is important to make sure that the structure of every [[Row]] of the provided RDD matches + * the provided schema. Otherwise, there will be runtime exception. + * Example: + * {{{ + * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ + * val sqlContext = new org.apache.spark.sql.SQLContext(sc) + * + * val schema = + * StructType( + * StructField("name", StringType, false) :: + * StructField("age", IntegerType, true) :: Nil) + * + * val people = + * sc.textFile("examples/src/main/resources/people.txt").map( + * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) + * val dataFrame = sqlContext. applySchema(people, schema) + * dataFrame.printSchema + * // root + * // |-- name: string (nullable = false) + * // |-- age: integer (nullable = true) + * + * dataFrame.registerTempTable("people") + * sqlContext.sql("select name from people").collect.foreach(println) + * }}} + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + /** + * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } + + /** + * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. + */ + @deprecated("use createDataFrame", "1.3.0") def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { - applySchema(rdd.rdd, beanClass) + createDataFrame(rdd, beanClass) } /** - * Loads a Parquet file, returning the result as a [[DataFrame]]. + * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty + * [[DataFrame]] if no paths are passed in. * - * @group userf + * @group specificdata */ - def parquetFile(path: String): DataFrame = - DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) + @scala.annotation.varargs + def parquetFile(paths: String*): DataFrame = { + if (paths.isEmpty) { + emptyDataFrame + } else if (conf.parquetUseDataSourceApi) { + baseRelationToDataFrame(parquet.ParquetRelation2(paths, Map.empty)(this)) + } else { + DataFrame(this, parquet.ParquetRelation( + paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) + } + } /** * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * - * @group userf + * @group specificdata */ def jsonFile(path: String): DataFrame = jsonFile(path, 1.0) @@ -279,38 +557,45 @@ class SQLContext(@transient val sparkContext: SparkContext) * Loads a JSON file (one object per line) and applies the given schema, * returning the result as a [[DataFrame]]. * - * @group userf + * @group specificdata */ @Experimental - def jsonFile(path: String, schema: StructType): DataFrame = { - val json = sparkContext.textFile(path) - jsonRDD(json, schema) - } + def jsonFile(path: String, schema: StructType): DataFrame = + load("json", schema, Map("path" -> path)) /** * :: Experimental :: + * @group specificdata */ @Experimental - def jsonFile(path: String, samplingRatio: Double): DataFrame = { - val json = sparkContext.textFile(path) - jsonRDD(json, samplingRatio) - } + def jsonFile(path: String, samplingRatio: Double): DataFrame = + load("json", Map("path" -> path, "samplingRatio" -> samplingRatio.toString)) /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a * [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * - * @group userf + * @group specificdata */ def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0) + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + */ + def jsonRDD(json: JavaRDD[String]): DataFrame = jsonRDD(json.rdd, 1.0) + /** * :: Experimental :: * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, * returning the result as a [[DataFrame]]. * - * @group userf + * @group specificdata */ @Experimental def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { @@ -320,11 +605,27 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - applySchema(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } /** * :: Experimental :: + * Loads an JavaRDD storing JSON objects (one object per record) and applies the given + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + */ + @Experimental + def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { + jsonRDD(json.rdd, schema) + } + + /** + * :: Experimental :: + * Loads an RDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata */ @Experimental def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { @@ -333,86 +634,279 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - applySchema(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } + /** + * :: Experimental :: + * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + */ + @Experimental + def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { + jsonRDD(json.rdd, samplingRatio); + } + + /** + * :: Experimental :: + * Returns the dataset stored at path as a DataFrame, + * using the default data source configured by spark.sql.sources.default. + * + * @group genericdata + */ @Experimental def load(path: String): DataFrame = { val dataSourceName = conf.defaultDataSourceName - load(dataSourceName, ("path", path)) + load(path, dataSourceName) } + /** + * :: Experimental :: + * Returns the dataset stored at path as a DataFrame, using the given data source. + * + * @group genericdata + */ @Experimental - def load( - dataSourceName: String, - option: (String, String), - options: (String, String)*): DataFrame = { - val resolved = ResolvedDataSource(this, None, dataSourceName, (option +: options).toMap) + def load(path: String, source: String): DataFrame = { + load(source, Map("path" -> path)) + } + + /** + * :: Experimental :: + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + */ + @Experimental + def load(source: String, options: java.util.Map[String, String]): DataFrame = { + load(source, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + */ + @Experimental + def load(source: String, options: Map[String, String]): DataFrame = { + val resolved = ResolvedDataSource(this, None, source, options) DataFrame(this, LogicalRelation(resolved.relation)) } + /** + * :: Experimental :: + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. + * + * @group genericdata + */ @Experimental def load( - dataSourceName: String, + source: String, + schema: StructType, options: java.util.Map[String, String]): DataFrame = { - val opts = options.toSeq - load(dataSourceName, opts.head, opts.tail:_*) + load(source, schema, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. + * @group genericdata + */ + @Experimental + def load( + source: String, + schema: StructType, + options: Map[String, String]): DataFrame = { + val resolved = ResolvedDataSource(this, Some(schema), source, options) + DataFrame(this, LogicalRelation(resolved.relation)) } /** * :: Experimental :: - * Construct an RDD representing the database table accessible via JDBC URL + * Creates an external table from the given path and returns the corresponding DataFrame. + * It will use the default data source configured by spark.sql.sources.default. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable(tableName: String, path: String): DataFrame = { + val dataSourceName = conf.defaultDataSourceName + createExternalTable(tableName, path, dataSourceName) + } + + /** + * :: Experimental :: + * Creates an external table from the given path based on a data source + * and returns the corresponding DataFrame. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable( + tableName: String, + path: String, + source: String): DataFrame = { + createExternalTable(tableName, source, Map("path" -> path)) + } + + /** + * :: Experimental :: + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: Map[String, String]): DataFrame = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = None, + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableName) + } + + /** + * :: Experimental :: + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, schema, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: Map[String, String]): DataFrame = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = Some(schema), + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableName) + } + + /** + * :: Experimental :: + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL * url named table. + * + * @group specificdata */ @Experimental - def jdbcRDD(url: String, table: String): DataFrame = { - jdbcRDD(url, table, null.asInstanceOf[JDBCPartitioningInfo]) + def jdbc(url: String, table: String): DataFrame = { + jdbc(url, table, JDBCRelation.columnPartition(null)) } /** * :: Experimental :: - * Construct an RDD representing the database table accessible via JDBC URL - * url named table. The PartitioningInfo parameter - * gives the name of a column of integral type, a number of partitions, and - * advisory minimum and maximum values for the column. The RDD is - * partitioned according to said column. + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. Partitions of the table will be retrieved in parallel based on the parameters + * passed to this function. + * + * @param columnName the name of a column of integral type that will be used for partitioning. + * @param lowerBound the minimum value of `columnName` to retrieve + * @param upperBound the maximum value of `columnName` to retrieve + * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split + * evenly into this many partitions + * + * @group specificdata */ @Experimental - def jdbcRDD(url: String, table: String, partitioning: JDBCPartitioningInfo): - DataFrame = { + def jdbc( + url: String, + table: String, + columnName: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int): DataFrame = { + val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions) val parts = JDBCRelation.columnPartition(partitioning) - jdbcRDD(url, table, parts) + jdbc(url, table, parts) } /** * :: Experimental :: - * Construct an RDD representing the database table accessible via JDBC URL + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL * url named table. The theParts parameter gives a list expressions * suitable for inclusion in WHERE clauses; each one defines one partition - * of the RDD. + * of the [[DataFrame]]. + * + * @group specificdata */ @Experimental - def jdbcRDD(url: String, table: String, theParts: Array[String]): - DataFrame = { - val parts: Array[Partition] = theParts.zipWithIndex.map( - x => JDBCPartition(x._1, x._2).asInstanceOf[Partition]) - jdbcRDD(url, table, parts) + def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { + val parts: Array[Partition] = theParts.zipWithIndex.map { case (part, i) => + JDBCPartition(part, i) : Partition + } + jdbc(url, table, parts) } - private def jdbcRDD(url: String, table: String, parts: Array[Partition]): - DataFrame = { + private def jdbc(url: String, table: String, parts: Array[Partition]): DataFrame = { val relation = JDBCRelation(url, table, parts)(this) baseRelationToDataFrame(relation) } /** - * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only - * during the lifetime of this instance of SQLContext. - * - * @group userf + * Registers the given [[DataFrame]] as a temporary table in the catalog. Temporary tables exist + * only during the lifetime of this instance of SQLContext. */ - def registerRDDAsTable(rdd: DataFrame, tableName: String): Unit = { - catalog.registerTable(Seq(tableName), rdd.logicalPlan) + private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { + catalog.registerTable(Seq(tableName), df.logicalPlan) } /** @@ -421,7 +915,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @param tableName the name of the table to be unregistered. * - * @group userf + * @group basic */ def dropTempTable(tableName: String): Unit = { cacheManager.tryUncacheQuery(table(tableName)) @@ -432,7 +926,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. * - * @group userf + * @group basic */ def sql(sqlText: String): DataFrame = { if (conf.dialect == "sql") { @@ -442,18 +936,66 @@ class SQLContext(@transient val sparkContext: SparkContext) } } - /** Returns the specified table as a [[DataFrame]]. */ + /** + * Returns the specified table as a [[DataFrame]]. + * + * @group ddl_ops + */ def table(tableName: String): DataFrame = DataFrame(this, catalog.lookupRelation(Seq(tableName))) + /** + * Returns a [[DataFrame]] containing names of existing tables in the current database. + * The returned DataFrame has two columns, tableName and isTemporary (a Boolean + * indicating if a table is a temporary one or not). + * + * @group ddl_ops + */ + def tables(): DataFrame = { + DataFrame(this, ShowTablesCommand(None)) + } + + /** + * Returns a [[DataFrame]] containing names of existing tables in the given database. + * The returned DataFrame has two columns, tableName and isTemporary (a Boolean + * indicating if a table is a temporary one or not). + * + * @group ddl_ops + */ + def tables(databaseName: String): DataFrame = { + DataFrame(this, ShowTablesCommand(Some(databaseName))) + } + + /** + * Returns the names of tables in the current database as an array. + * + * @group ddl_ops + */ + def tableNames(): Array[String] = { + catalog.getTables(None).map { + case (tableName, _) => tableName + }.toArray + } + + /** + * Returns the names of tables in the given database as an array. + * + * @group ddl_ops + */ + def tableNames(databaseName: String): Array[String] = { + catalog.getTables(Some(databaseName)).map { + case (tableName, _) => tableName + }.toArray + } + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext val sqlContext: SQLContext = self - def codegenEnabled = self.conf.codegenEnabled + def codegenEnabled: Boolean = self.conf.codegenEnabled - def numPartitions = self.conf.numShufflePartitions + def numPartitions: Int = self.conf.numShufflePartitions def strategies: Seq[Strategy] = experimental.extraStrategies ++ ( @@ -490,7 +1032,8 @@ class SQLContext(@transient val sparkContext: SparkContext) val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And) + val filterCondition = + prunePushedDownFilters(filterPredicates).reduceLeftOption(expressions.And) // Right now we still use a projection even if the only evaluation is applying an alias // to a column. Since this is a no-op, it could be avoided. However, using this @@ -534,9 +1077,13 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @DeveloperApi protected[sql] class QueryExecution(val logical: LogicalPlan) { + def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) - lazy val analyzed: LogicalPlan = ExtractPythonUdfs(analyzer(logical)) - lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed) + lazy val analyzed: LogicalPlan = analyzer(logical) + lazy val withCachedData: LogicalPlan = { + assertAnalyzed() + cacheManager.useCachedData(analyzed) + } lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData) // TODO: Don't just pick the first one... @@ -605,6 +1152,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def needsConversion(dataType: DataType): Boolean = dataType match { case ByteType => true case ShortType => true + case LongType => true case FloatType => true case DateType => true case TimestampType => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index f1a4053b79113..5921eaf5e63f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -23,7 +23,7 @@ import scala.util.parsing.combinator.RegexParsers import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.types.StringType @@ -57,12 +57,16 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr protected val AS = Keyword("AS") protected val CACHE = Keyword("CACHE") + protected val CLEAR = Keyword("CLEAR") + protected val IN = Keyword("IN") protected val LAZY = Keyword("LAZY") protected val SET = Keyword("SET") + protected val SHOW = Keyword("SHOW") protected val TABLE = Keyword("TABLE") + protected val TABLES = Keyword("TABLES") protected val UNCACHE = Keyword("UNCACHE") - override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others + override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others private lazy val cache: Parser[LogicalPlan] = CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { @@ -71,15 +75,22 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr } private lazy val uncache: Parser[LogicalPlan] = - UNCACHE ~ TABLE ~> ident ^^ { - case tableName => UncacheTableCommand(tableName) - } + ( UNCACHE ~ TABLE ~> ident ^^ { + case tableName => UncacheTableCommand(tableName) + } + | CLEAR ~ CACHE ^^^ ClearCacheCommand + ) private lazy val set: Parser[LogicalPlan] = SET ~> restInput ^^ { case input => SetCommandParser(input) } + private lazy val show: Parser[LogicalPlan] = + SHOW ~> TABLES ~ (IN ~> ident).? ^^ { + case _ ~ dbName => ShowTablesCommand(dbName) + } + private lazy val others: Parser[LogicalPlan] = wholeInput ^^ { case input => fallback(input) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala similarity index 56% rename from sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala rename to sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 1beb19437a8da..b97aaf73529a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.types.DataType /** - * Functions for registering user-defined functions. + * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. */ -class UDFRegistration(sqlContext: SQLContext) extends Logging { +class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { private val functionRegistry = sqlContext.functionRegistry @@ -61,7 +61,7 @@ class UDFRegistration(sqlContext: SQLContext) extends Logging { val dataType = sqlContext.parseDataType(stringDataType) - def builder(e: Seq[Expression]) = + def builder(e: Seq[Expression]): PythonUDF = PythonUDF( name, command, @@ -78,20 +78,21 @@ class UDFRegistration(sqlContext: SQLContext) extends Logging { // scalastyle:off - /* registerFunction 0-22 were generated by this script + /* register 0-22 were generated by this script (0 to 22).map { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val argDocs = (1 to x).map(i => s" * @tparam A$i type of the UDF argument at position $i.").foldLeft("")(_ + "\n" + _) println(s""" /** * Register a Scala closure of ${x} arguments as user-defined function (UDF). - * @tparam RT return type of UDF.$argDocs + * @tparam RT return type of UDF. */ - def register[$typeTags](name: String, func: Function$x[$types]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) }""") } @@ -116,462 +117,258 @@ class UDFRegistration(sqlContext: SQLContext) extends Logging { * Register a Scala closure of 0 arguments as user-defined function (UDF). * @tparam RT return type of UDF. */ - def register[RT: TypeTag](name: String, func: Function0[RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 1 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. */ - def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 2 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 3 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 4 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 5 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 6 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 7 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 8 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 9 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 10 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 11 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 12 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 13 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 14 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 15 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 16 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 17 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 18 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - * @tparam A18 type of the UDF argument at position 18. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 19 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - * @tparam A18 type of the UDF argument at position 18. - * @tparam A19 type of the UDF argument at position 19. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 20 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - * @tparam A18 type of the UDF argument at position 18. - * @tparam A19 type of the UDF argument at position 19. - * @tparam A20 type of the UDF argument at position 20. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 21 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - * @tparam A18 type of the UDF argument at position 18. - * @tparam A19 type of the UDF argument at position 19. - * @tparam A20 type of the UDF argument at position 20. - * @tparam A21 type of the UDF argument at position 21. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 22 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - * @tparam A18 type of the UDF argument at position 18. - * @tparam A19 type of the UDF argument at position 19. - * @tparam A20 type of the UDF argument at position 20. - * @tparam A21 type of the UDF argument at position 21. - * @tparam A22 type of the UDF argument at position 22. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } + ////////////////////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////////////////////// + /** * Register a user-defined function with 1 arguments. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala new file mode 100644 index 0000000000000..295db539adfc4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -0,0 +1,67 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.Accumulator +import org.apache.spark.api.python.PythonBroadcast +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.execution.PythonUDF +import org.apache.spark.sql.types.DataType + +/** + * A user-defined function. To create one, use the `udf` functions in [[functions]]. + * As an example: + * {{{ + * // Defined a UDF that returns true or false based on some numeric score. + * val predict = udf((score: Double) => if (score > 0.5) true else false) + * + * // Projects a column that adds a prediction column based on the score column. + * df.select( predict(df("score")) ) + * }}} + */ +case class UserDefinedFunction protected[sql] (f: AnyRef, dataType: DataType) { + + def apply(exprs: Column*): Column = { + Column(ScalaUdf(f, dataType, exprs.map(_.expr))) + } +} + +/** + * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]]. + * This is used by Python API. + */ +private[sql] case class UserDefinedPythonFunction( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + dataType: DataType) { + + /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ + def apply(exprs: Column*): Column = { + val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars, + accumulator, dataType, exprs.map(_.expr)) + Column(udf) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/package.scala new file mode 100644 index 0000000000000..cbbd005228d44 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/package.scala @@ -0,0 +1,23 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +/** + * Contains API classes that are specific to a single language (i.e. Java). + */ +package object api diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 91c4c105b14e6..b7b3fc221e340 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -48,9 +48,9 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( protected def initialize() {} - def hasNext = buffer.hasRemaining + override def hasNext: Boolean = buffer.hasRemaining - def extractTo(row: MutableRow, ordinal: Int): Unit = { + override def extractTo(row: MutableRow, ordinal: Int): Unit = { extractSingle(row, ordinal) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 3a4977b836af7..b7d93d5dc2cf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -58,7 +58,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( override def initialize( initialSize: Int, columnName: String = "", - useCompression: Boolean = false) = { + useCompression: Boolean = false): Unit = { val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize this.columnName = columnName @@ -73,7 +73,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( columnType.append(row, ordinal, buffer) } - override def build() = { + override def build(): ByteBuffer = { buffer.flip().asInstanceOf[ByteBuffer] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 391b3dae5c8ce..07b61af2d9985 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.columnar -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} @@ -76,7 +76,7 @@ private[sql] sealed trait ColumnStats extends Serializable { private[sql] class NoopColumnStats extends ColumnStats { override def gatherStats(row: Row, ordinal: Int): Unit = super.gatherStats(row, ordinal) - def collectedStatistics = Row(null, null, nullCount, count, 0L) + override def collectedStatistics: Row = Row(null, null, nullCount, count, 0L) } private[sql] class BooleanColumnStats extends ColumnStats { @@ -93,7 +93,7 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ByteColumnStats extends ColumnStats { @@ -110,7 +110,7 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ShortColumnStats extends ColumnStats { @@ -127,7 +127,7 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class LongColumnStats extends ColumnStats { @@ -144,7 +144,7 @@ private[sql] class LongColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class DoubleColumnStats extends ColumnStats { @@ -161,7 +161,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class FloatColumnStats extends ColumnStats { @@ -178,7 +178,7 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class IntColumnStats extends ColumnStats { @@ -195,7 +195,7 @@ private[sql] class IntColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class StringColumnStats extends ColumnStats { @@ -212,25 +212,10 @@ private[sql] class StringColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DateColumnStats extends ColumnStats { - protected var upper: Date = null - protected var lower: Date = null - - override def gatherStats(row: Row, ordinal: Int) { - super.gatherStats(row, ordinal) - if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Date] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += DATE.defaultSize - } - } - - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) -} +private[sql] class DateColumnStats extends IntColumnStats private[sql] class TimestampColumnStats extends ColumnStats { protected var upper: Timestamp = null @@ -246,7 +231,7 @@ private[sql] class TimestampColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class BinaryColumnStats extends ColumnStats { @@ -257,7 +242,7 @@ private[sql] class BinaryColumnStats extends ColumnStats { } } - def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes) } private[sql] class GenericColumnStats extends ColumnStats { @@ -268,5 +253,5 @@ private[sql] class GenericColumnStats extends ColumnStats { } } - def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index fcf2faa0914c0..69d89dda6a929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -98,7 +98,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( */ def clone(v: JvmType): JvmType = v - override def toString = getClass.getSimpleName.stripSuffix("$") + override def toString: String = getClass.getSimpleName.stripSuffix("$") } private[sql] abstract class NativeColumnType[T <: NativeType]( @@ -114,7 +114,7 @@ private[sql] abstract class NativeColumnType[T <: NativeType]( } private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { - def append(v: Int, buffer: ByteBuffer): Unit = { + override def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } @@ -122,7 +122,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { buffer.putInt(row.getInt(ordinal)) } - def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Int = { buffer.getInt() } @@ -134,7 +134,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { row.setInt(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getInt(ordinal) + override def getField(row: Row, ordinal: Int): Int = row.getInt(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setInt(toOrdinal, from.getInt(fromOrdinal)) @@ -150,7 +150,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { buffer.putLong(row.getLong(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Long = { buffer.getLong() } @@ -162,7 +162,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { row.setLong(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getLong(ordinal) + override def getField(row: Row, ordinal: Int): Long = row.getLong(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setLong(toOrdinal, from.getLong(fromOrdinal)) @@ -178,7 +178,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { buffer.putFloat(row.getFloat(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Float = { buffer.getFloat() } @@ -190,7 +190,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { row.setFloat(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal) + override def getField(row: Row, ordinal: Int): Float = row.getFloat(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) @@ -206,7 +206,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { buffer.putDouble(row.getDouble(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Double = { buffer.getDouble() } @@ -218,7 +218,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { row.setDouble(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal) + override def getField(row: Row, ordinal: Int): Double = row.getDouble(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) @@ -234,7 +234,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) } - override def extract(buffer: ByteBuffer) = buffer.get() == 1 + override def extract(buffer: ByteBuffer): Boolean = buffer.get() == 1 override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { row.setBoolean(ordinal, buffer.get() == 1) @@ -244,7 +244,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { row.setBoolean(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal) + override def getField(row: Row, ordinal: Int): Boolean = row.getBoolean(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) @@ -260,7 +260,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { buffer.put(row.getByte(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Byte = { buffer.get() } @@ -272,7 +272,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { row.setByte(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getByte(ordinal) + override def getField(row: Row, ordinal: Int): Byte = row.getByte(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setByte(toOrdinal, from.getByte(fromOrdinal)) @@ -288,7 +288,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { buffer.putShort(row.getShort(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Short = { buffer.getShort() } @@ -300,7 +300,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { row.setShort(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getShort(ordinal) + override def getField(row: Row, ordinal: Int): Short = row.getShort(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setShort(toOrdinal, from.getShort(fromOrdinal)) @@ -317,7 +317,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): String = { val length = buffer.getInt() val stringBytes = new Array[Byte](length) buffer.get(stringBytes, 0, length) @@ -328,34 +328,33 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { row.setString(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getString(ordinal) + override def getField(row: Row, ordinal: Int): String = row.getString(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setString(toOrdinal, from.getString(fromOrdinal)) } } -private[sql] object DATE extends NativeColumnType(DateType, 8, 8) { - override def extract(buffer: ByteBuffer) = { - val date = new Date(buffer.getLong()) - date +private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { + override def extract(buffer: ByteBuffer): Int = { + buffer.getInt } - override def append(v: Date, buffer: ByteBuffer): Unit = { - buffer.putLong(v.getTime) + override def append(v: Int, buffer: ByteBuffer): Unit = { + buffer.putInt(v) } - override def getField(row: Row, ordinal: Int) = { - row(ordinal).asInstanceOf[Date] + override def getField(row: Row, ordinal: Int): Int = { + row(ordinal).asInstanceOf[Int] } - override def setField(row: MutableRow, ordinal: Int, value: Date): Unit = { + def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { row(ordinal) = value } } private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Timestamp = { val timestamp = new Timestamp(buffer.getLong()) timestamp.setNanos(buffer.getInt()) timestamp @@ -365,7 +364,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { buffer.putLong(v.getTime).putInt(v.getNanos) } - override def getField(row: Row, ordinal: Int) = { + override def getField(row: Row, ordinal: Int): Timestamp = { row(ordinal).asInstanceOf[Timestamp] } @@ -379,7 +378,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( defaultSize: Int) extends ColumnType[T, Array[Byte]](typeId, defaultSize) { - override def actualSize(row: Row, ordinal: Int) = { + override def actualSize(row: Row, ordinal: Int): Int = { getField(row, ordinal).length + 4 } @@ -387,7 +386,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( buffer.putInt(v.length).put(v, 0, v.length) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Array[Byte] = { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes, 0, length) @@ -400,7 +399,9 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16) row(ordinal) = value } - override def getField(row: Row, ordinal: Int) = row(ordinal).asInstanceOf[Array[Byte]] + override def getField(row: Row, ordinal: Int): Array[Byte] = { + row(ordinal).asInstanceOf[Array[Byte]] + } } // Used to process generic objects (all types other than those listed above). Objects should be @@ -411,7 +412,9 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } - override def getField(row: Row, ordinal: Int) = SparkSqlSerializer.serialize(row(ordinal)) + override def getField(row: Row, ordinal: Int): Array[Byte] = { + SparkSqlSerializer.serialize(row(ordinal)) + } } private[sql] object ColumnType { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 11d5943fb427f..c0fcbdfe9237a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import org.apache.spark.Accumulator +import org.apache.spark.sql.catalyst.expressions + import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD @@ -77,20 +80,23 @@ private[sql] case class InMemoryRelation( _statistics } - override def statistics = if (_statistics == null) { - if (batchStats.value.isEmpty) { - // Underlying columnar RDD hasn't been materialized, no useful statistics information - // available, return the default statistics. - Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) + override def statistics: Statistics = { + if (_statistics == null) { + if (batchStats.value.isEmpty) { + // Underlying columnar RDD hasn't been materialized, no useful statistics information + // available, return the default statistics. + Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) + } else { + // Underlying columnar RDD has been materialized, required information has also been + // collected via the `batchStats` accumulator, compute the final statistics, + // and update `_statistics`. + _statistics = Statistics(sizeInBytes = computeSizeInBytes) + _statistics + } } else { - // Underlying columnar RDD has been materialized, required information has also been collected - // via the `batchStats` accumulator, compute the final statistics, and update `_statistics`. - _statistics = Statistics(sizeInBytes = computeSizeInBytes) + // Pre-computed statistics _statistics } - } else { - // Pre-computed statistics - _statistics } // If the cached column buffers were not passed in, we calculate them in the constructor. @@ -99,7 +105,7 @@ private[sql] case class InMemoryRelation( buildBuffers() } - def recache() = { + def recache(): Unit = { _cachedColumnBuffers.unpersist() _cachedColumnBuffers = null buildBuffers() @@ -109,7 +115,7 @@ private[sql] case class InMemoryRelation( val output = child.output val cached = child.execute().mapPartitions { rowIterator => new Iterator[CachedBatch] { - def next() = { + def next(): CachedBatch = { val columnBuilders = output.map { attribute => val columnType = ColumnType(attribute.dataType) val initialBufferSize = columnType.defaultSize * batchSize @@ -119,6 +125,17 @@ private[sql] case class InMemoryRelation( var rowCount = 0 while (rowIterator.hasNext && rowCount < batchSize) { val row = rowIterator.next() + + // Added for SPARK-6082. This assertion can be useful for scenarios when something + // like Hive TRANSFORM is used. The external data generation script used in TRANSFORM + // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat + // hard to decipher. + assert( + row.size == columnBuilders.size, + s"""Row column number mismatch, expected ${output.size} columns, but got ${row.size}. + |Row content: $row + """.stripMargin) + var i = 0 while (i < row.length) { columnBuilders(i).appendFrom(row, i) @@ -133,7 +150,7 @@ private[sql] case class InMemoryRelation( CachedBatch(columnBuilders.map(_.build().array()), stats) } - def hasNext = rowIterator.hasNext + def hasNext: Boolean = rowIterator.hasNext } }.persist(storageLevel) @@ -147,9 +164,9 @@ private[sql] case class InMemoryRelation( _cachedColumnBuffers, statisticsToBePropagated) } - override def children = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty - override def newInstance() = { + override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), useCompression, @@ -161,7 +178,7 @@ private[sql] case class InMemoryRelation( statisticsToBePropagated).asInstanceOf[this.type] } - def cachedColumnBuffers = _cachedColumnBuffers + def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers override protected def otherCopyArgs: Seq[AnyRef] = Seq(_cachedColumnBuffers, statisticsToBePropagated) @@ -209,7 +226,7 @@ private[sql] case class InMemoryColumnarTableScan( case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 } - val partitionFilters = { + val partitionFilters: Seq[Expression] = { predicates.flatMap { p => val filter = buildFilter.lift(p) val boundFilter = @@ -228,12 +245,12 @@ private[sql] case class InMemoryColumnarTableScan( } // Accumulators used for testing purposes - val readPartitions = sparkContext.accumulator(0) - val readBatches = sparkContext.accumulator(0) + val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) + val readBatches: Accumulator[Int] = sparkContext.accumulator(0) private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning - override def execute() = { + override def execute(): RDD[Row] = { readPartitions.setValue(0) readBatches.setValue(0) @@ -260,7 +277,7 @@ private[sql] case class InMemoryColumnarTableScan( val nextRow = new SpecificMutableRow(requestedColumnDataTypes) - def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { + def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[Row] = { val rows = cacheBatches.flatMap { cachedBatch => // Build column accessors val columnAccessors = requestedColumnIndices.map { batch => @@ -270,7 +287,7 @@ private[sql] case class InMemoryColumnarTableScan( // Extract rows via column accessors new Iterator[Row] { private[this] val rowLen = nextRow.length - override def next() = { + override def next(): Row = { var i = 0 while (i < rowLen) { columnAccessors(i).extractTo(nextRow, i) @@ -279,7 +296,7 @@ private[sql] case class InMemoryColumnarTableScan( nextRow } - override def hasNext = columnAccessors(0).hasNext + override def hasNext: Boolean = columnAccessors(0).hasNext } } @@ -295,7 +312,7 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString = relation.partitionStatistics.schema + def statsString: String = relation.partitionStatistics.schema .zip(cachedBatch.stats.toSeq) .map { case (a, s) => s"${a.name}: $s" } .mkString(", ") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala index 965782a40031b..4d35650d4b1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala @@ -55,5 +55,5 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { pos += 1 } - abstract override def hasNext = seenNulls < nullCount || super.hasNext + abstract override def hasNext: Boolean = seenNulls < nullCount || super.hasNext } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala index 7dff9deac8dc0..d0b602a834dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -26,12 +26,12 @@ private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAcc private var decoder: Decoder[T] = _ - abstract override protected def initialize() = { + abstract override protected def initialize(): Unit = { super.initialize() decoder = CompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType) } - abstract override def hasNext = super.hasNext || decoder.hasNext + abstract override def hasNext: Boolean = super.hasNext || decoder.hasNext override def extractSingle(row: MutableRow, ordinal: Int): Unit = { decoder.next(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index aead768ecdf0a..b9cfc5df550d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -81,7 +81,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] } } - override def build() = { + override def build(): ByteBuffer = { val nonNullBuffer = buildNonNulls() val typeId = nonNullBuffer.getInt() val encoder: Encoder[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 68a5b1de7691b..8727d71c48bb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -33,22 +33,23 @@ import org.apache.spark.util.Utils private[sql] case object PassThrough extends CompressionScheme { override val typeId = 0 - override def supports(columnType: ColumnType[_, _]) = true + override def supports(columnType: ColumnType[_, _]): Boolean = true - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType]( + buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = { new this.Decoder(buffer, columnType) } class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { - override def uncompressedSize = 0 + override def uncompressedSize: Int = 0 - override def compressedSize = 0 + override def compressedSize: Int = 0 - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { // Writes compression type ID and copies raw contents to.putInt(PassThrough.typeId).put(from).rewind() to @@ -62,22 +63,23 @@ private[sql] case object PassThrough extends CompressionScheme { columnType.extract(buffer, row, ordinal) } - override def hasNext = buffer.hasRemaining + override def hasNext: Boolean = buffer.hasRemaining } } private[sql] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType]( + buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = { new this.Decoder(buffer, columnType) } - override def supports(columnType: ColumnType[_, _]) = columnType match { + override def supports(columnType: ColumnType[_, _]): Boolean = columnType match { case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true case _ => false } @@ -90,9 +92,9 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { private val lastValue = new SpecificMutableRow(Seq(columnType.dataType)) private var lastRun = 0 - override def uncompressedSize = _uncompressedSize + override def uncompressedSize: Int = _uncompressedSize - override def compressedSize = _compressedSize + override def compressedSize: Int = _compressedSize override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { val value = columnType.getField(row, ordinal) @@ -114,7 +116,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { to.putInt(RunLengthEncoding.typeId) if (from.hasRemaining) { @@ -169,7 +171,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { columnType.setField(row, ordinal, currentValue) } - override def hasNext = valueCount < run || buffer.hasRemaining + override def hasNext: Boolean = valueCount < run || buffer.hasRemaining } } @@ -179,15 +181,16 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { // 32K unique values allowed val MAX_DICT_SIZE = Short.MaxValue - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : Decoder[T] = { new this.Decoder(buffer, columnType) } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def supports(columnType: ColumnType[_, _]) = columnType match { + override def supports(columnType: ColumnType[_, _]): Boolean = columnType match { case INT | LONG | STRING => true case _ => false } @@ -237,7 +240,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { if (overflow) { throw new IllegalStateException( "Dictionary encoding should not be used because of dictionary overflow.") @@ -260,9 +263,9 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { to } - override def uncompressedSize = _uncompressedSize + override def uncompressedSize: Int = _uncompressedSize - override def compressedSize = if (overflow) Int.MaxValue else dictionarySize + count * 2 + override def compressedSize: Int = if (overflow) Int.MaxValue else dictionarySize + count * 2 } class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) @@ -284,7 +287,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { columnType.setField(row, ordinal, dictionary(buffer.getShort())) } - override def hasNext = buffer.hasRemaining + override def hasNext: Boolean = buffer.hasRemaining } } @@ -293,15 +296,16 @@ private[sql] case object BooleanBitSet extends CompressionScheme { val BITS_PER_LONG = 64 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : compression.Decoder[T] = { new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new this.Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]) = columnType == BOOLEAN + override def supports(columnType: ColumnType[_, _]): Boolean = columnType == BOOLEAN class Encoder extends compression.Encoder[BooleanType.type] { private var _uncompressedSize = 0 @@ -310,7 +314,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { _uncompressedSize += BOOLEAN.defaultSize } - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { to.putInt(BooleanBitSet.typeId) // Total element count (1 byte per Boolean value) .putInt(from.remaining) @@ -347,9 +351,9 @@ private[sql] case object BooleanBitSet extends CompressionScheme { to } - override def uncompressedSize = _uncompressedSize + override def uncompressedSize: Int = _uncompressedSize - override def compressedSize = { + override def compressedSize: Int = { val extra = if (_uncompressedSize % BITS_PER_LONG == 0) 0 else 1 (_uncompressedSize / BITS_PER_LONG + extra) * 8 + 4 } @@ -380,22 +384,23 @@ private[sql] case object BooleanBitSet extends CompressionScheme { private[sql] case object IntDelta extends CompressionScheme { override def typeId: Int = 4 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : compression.Decoder[T] = { new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]) = columnType == INT + override def supports(columnType: ColumnType[_, _]): Boolean = columnType == INT class Encoder extends compression.Encoder[IntegerType.type] { protected var _compressedSize: Int = 0 protected var _uncompressedSize: Int = 0 - override def compressedSize = _compressedSize - override def uncompressedSize = _uncompressedSize + override def compressedSize: Int = _compressedSize + override def uncompressedSize: Int = _uncompressedSize private var prevValue: Int = _ @@ -459,22 +464,23 @@ private[sql] case object IntDelta extends CompressionScheme { private[sql] case object LongDelta extends CompressionScheme { override def typeId: Int = 5 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : compression.Decoder[T] = { new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]) = columnType == LONG + override def supports(columnType: ColumnType[_, _]): Boolean = columnType == LONG class Encoder extends compression.Encoder[LongType.type] { protected var _compressedSize: Int = 0 protected var _uncompressedSize: Int = 0 - override def compressedSize = _compressedSize - override def uncompressedSize = _uncompressedSize + override def compressedSize: Int = _compressedSize + override def uncompressedSize: Int = _uncompressedSize private var prevValue: Long = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala old mode 100755 new mode 100644 index ad44a01d0e164..18b1ba4c5c4b9 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -21,6 +21,7 @@ import java.util.HashMap import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -45,7 +46,7 @@ case class Aggregate( child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + override def requiredChildDistribution: List[Distribution] = { if (partial) { UnspecifiedDistribution :: Nil } else { @@ -55,8 +56,9 @@ case class Aggregate( ClusteredDistribution(groupingExpressions) :: Nil } } + } - override def output = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) /** * An aggregate that needs to be computed for each row in a group. @@ -119,7 +121,7 @@ case class Aggregate( } } - override def execute() = attachTree(this, "execute") { + override def execute(): RDD[Row] = attachTree(this, "execute") { if (groupingExpressions.isEmpty) { child.execute().mapPartitions { iter => val buffer = newAggregateBuffer() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 7c0b72aab448e..437408d30bfd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} -import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.RowOrdering +import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -34,9 +35,9 @@ import org.apache.spark.util.MutablePair @DeveloperApi case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { - override def outputPartitioning = newPartitioning + override def outputPartitioning: Partitioning = newPartitioning - override def output = child.output + override def output: Seq[Attribute] = child.output /** We must copy rows when sort based shuffle is on */ protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] @@ -44,7 +45,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val bypassMergeThreshold = child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - override def execute() = attachTree(this , "execute") { + override def execute(): RDD[Row] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. @@ -123,13 +124,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una */ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. - def numPartitions = sqlContext.conf.numShufflePartitions + def numPartitions: Int = sqlContext.conf.numShufflePartitions def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => // Check if every child's outputPartitioning satisfies the corresponding // required data distribution. - def meetsRequirements = + def meetsRequirements: Boolean = !operator.requiredChildDistribution.zip(operator.children).map { case (required, child) => val valid = child.outputPartitioning.satisfies(required) @@ -147,7 +148,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl // datasets are both clustered by "a", but these two outputPartitionings are not // compatible. // TODO: ASSUMES TRANSITIVITY? - def compatible = + def compatible: Boolean = !operator.children .map(_.outputPartitioning) .sliding(2) @@ -158,7 +159,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl // Check if the partitioning we want to ensure is the same as the child's output // partitioning. If so, we do not need to add the Exchange operator. - def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan) = + def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan = if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child if (meetsRequirements && compatible) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 20b14834bb0d5..d8955725e59b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType +import scala.collection.immutable + /** * :: DeveloperApi :: */ @@ -54,59 +56,49 @@ object RDDConversions { } } +/** Logical plan node for scanning data from an RDD. */ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) extends LogicalPlan with MultiInstanceRelation { - def children = Nil + override def children: Seq[LogicalPlan] = Nil - def newInstance() = + override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] - override def sameResult(plan: LogicalPlan) = plan match { + override def sameResult(plan: LogicalPlan): Boolean = plan match { case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id case _ => false } - @transient override lazy val statistics = Statistics( + @transient override lazy val statistics: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes) ) } +/** Physical plan node for scanning data from an RDD. */ case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - override def execute() = rdd -} - -@deprecated("Use LogicalRDD", "1.2.0") -case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - override def execute() = rdd + override def execute(): RDD[Row] = rdd } -@deprecated("Use LogicalRDD", "1.2.0") -case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) - extends LogicalPlan with MultiInstanceRelation { +/** Logical plan node for scanning data from a local collection. */ +case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext) + extends LogicalPlan with MultiInstanceRelation { - def output = alreadyPlanned.output - override def children = Nil + override def children: Seq[LogicalPlan] = Nil - override final def newInstance(): this.type = { - SparkLogicalPlan( - alreadyPlanned match { - case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance()), rdd) - case _ => sys.error("Multiple instance of the same relation detected.") - })(sqlContext).asInstanceOf[this.type] - } + override def newInstance(): this.type = + LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type] - override def sameResult(plan: LogicalPlan) = plan match { - case SparkLogicalPlan(ExistingRdd(_, rdd)) => - rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id + override def sameResult(plan: LogicalPlan): Boolean = plan match { + case LogicalRDD(_, otherRDD) => rows == rows case _ => false } - @transient override lazy val statistics = Statistics( - // TODO: Instead of returning a default value here, find a way to return a meaningful size - // estimate for RDDs. See PR 1238 for more discussions. - sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes) + @transient override lazy val statistics: Statistics = Statistics( + // TODO: Improve the statistics estimation. + // This is made small enough so it can be broadcasted. + sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1 ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 95172420608f9..575849481faad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ @@ -42,7 +43,7 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) - override def execute() = attachTree(this, "execute") { + override def execute(): RDD[Row] = attachTree(this, "execute") { child.execute().mapPartitions { iter => // TODO Move out projection objects creation and transfer to // workers via closure. However we can't assume the Projection @@ -55,7 +56,7 @@ case class Expand( private[this] var idx = -1 // -1 means the initial state private[this] var input: Row = _ - override final def hasNext = (-1 < idx && idx < groups.length) || iter.hasNext + override final def hasNext: Boolean = (-1 < idx && idx < groups.length) || iter.hasNext override final def next(): Row = { if (idx <= 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 38877c28de3a8..12271048bb39c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ /** @@ -54,7 +55,7 @@ case class Generate( val boundGenerator = BindReferences.bindReference(generator, child.output) - override def execute() = { + override def execute(): RDD[Row] = { if (join) { child.execute().mapPartitions { iter => val nullValues = Seq.fill(generator.output.size)(Literal(null)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 4abe26fe4afc6..89682d25ca7dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -49,7 +50,7 @@ case class GeneratedAggregate( child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = if (partial) { UnspecifiedDistribution :: Nil } else { @@ -60,9 +61,9 @@ case class GeneratedAggregate( } } - override def output = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - override def execute() = { + override def execute(): RDD[Row] = { val aggregatesToCompute = aggregateExpressions.flatMap { a => a.collect { case agg: AggregateExpression => agg} } @@ -271,9 +272,9 @@ case class GeneratedAggregate( private[this] val resultIterator = buffers.entrySet.iterator() private[this] val resultProjection = resultProjectionBuilder() - def hasNext = resultIterator.hasNext + def hasNext: Boolean = resultIterator.hasNext - def next() = { + def next(): Row = { val currentGroup = resultIterator.next() resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala new file mode 100644 index 0000000000000..5bd699a2fa949 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.Attribute + + +/** + * Physical plan node for scanning data from a local collection. + */ +case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode { + + private lazy val rdd = sqlContext.sparkContext.parallelize(rows) + + override def execute(): RDD[Row] = rdd + + override def executeCollect(): Array[Row] = + rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray + + override def executeTake(limit: Int): Array[Row] = + rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 6fecd1ff066c3..d239637cd4b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ +import scala.collection.mutable.ArrayBuffer + object SparkPlan { protected[sql] val currentContext = new ThreadLocal[SQLContext]() } @@ -65,6 +67,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! + /** Specifies any partition requirements on the input data for this operator. */ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) @@ -77,8 +80,53 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. */ - def executeCollect(): Array[Row] = + def executeCollect(): Array[Row] = { execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() + } + + /** + * Runs this query returning the first `n` rows as an array. + * + * This is modeled after RDD.take but never runs any job locally on the driver. + */ + def executeTake(n: Int): Array[Row] = { + if (n == 0) { + return new Array[Row](0) + } + + val childRDD = execute().map(_.copy()) + + val buf = new ArrayBuffer[Row] + val totalParts = childRDD.partitions.length + var partsScanned = 0 + while (buf.size < n && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = 1 + if (partsScanned > 0) { + // If we didn't find any rows after the first iteration, just try all partitions next. + // Otherwise, interpolate the number of partitions we need to try, but overestimate it + // by 50%. + if (buf.size == 0) { + numPartsToTry = totalParts - 1 + } else { + numPartsToTry = (1.5 * n * partsScanned / buf.size).toInt + } + } + numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions + + val left = n - buf.size + val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val sc = sqlContext.sparkContext + val res = + sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false) + + res.foreach(buf ++= _.take(n - buf.size)) + partsScanned += numPartsToTry + } + + buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) + } protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 30564e14fa896..967bd76b302d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHa private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { - val kryo = new Kryo() + val kryo = super.newKryo() kryo.setRegistrationRequired(false) kryo.register(classOf[MutablePair[_, _]]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) @@ -57,8 +57,6 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[Decimal]) kryo.setReferences(false) - kryo.setClassLoader(Utils.getSparkClassLoader) - new AllScalaRegistrar().apply(kryo) kryo } } @@ -74,7 +72,7 @@ private[execution] class KryoResourcePool(size: Int) new KryoSerializer(sparkConf) } - def newInstance() = ser.newInstance() + def newInstance(): SerializerInstance = ser.newInstance() } private[sql] object SparkSqlSerializer { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ff0609d4b3b72..f754fa770d1b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.parquet._ +import org.apache.spark.sql.sources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.types._ -import org.apache.spark.sql.sources._ - +import org.apache.spark.sql.{SQLContext, Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -154,7 +154,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { + def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && @@ -162,7 +162,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => true } - def allAggregates(exprs: Seq[Expression]) = + def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] = exprs.flatMap(_.collect { case a: AggregateExpression => a }) } @@ -257,7 +257,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { - def numPartitions = self.numPartitions + def numPartitions: Int = self.numPartitions def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil @@ -284,13 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil - case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => - val nPartitions = if (data.isEmpty) 1 else numPartitions - PhysicalRDD( - output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions), - StructType.fromAttributes(output))) :: Nil + LocalTableScan(output, data) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => @@ -301,7 +296,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Intersect(planLater(left), planLater(right)) :: Nil case logical.Generate(generator, join, outer, _, child) => execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil - case logical.NoRelation => + case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil @@ -314,7 +309,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object DDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false) => + case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) => ExecutedCommand( CreateTempTableUsing( tableName, userSpecifiedSchema, provider, opts)) :: Nil @@ -323,24 +318,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsSelect(tableName, provider, true, opts, false, query) => - val logicalPlan = sqlContext.parseSql(query) + case CreateTableUsingAsSelect(tableName, provider, true, mode, opts, query) => val cmd = - CreateTempTableUsingAsSelect(tableName, provider, opts, logicalPlan) + CreateTempTableUsingAsSelect(tableName, provider, mode, opts, query) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") - case c: CreateTableUsingAsSelect if c.temporary && c.allowExisting => - sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsLogicalPlan(tableName, provider, true, opts, false, query) => - val cmd = - CreateTempTableUsingAsSelect(tableName, provider, opts, query) - ExecutedCommand(cmd) :: Nil - case c: CreateTableUsingAsLogicalPlan if !c.temporary => - sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") - case c: CreateTableUsingAsLogicalPlan if c.temporary && c.allowExisting => - sys.error("allowExisting should be set to false when creating a temporary table.") + case LogicalDescribeCommand(table, isExtended) => + val resultPlan = self.sqlContext.executePlan(table).executedPlan + ExecutedCommand( + RunnableDescribeCommand(resultPlan, resultPlan.output, isExtended)) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 16ca4be5587c4..1f5251a20376f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe.TypeTag - import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} @@ -27,8 +24,8 @@ import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, SinglePartition, UnspecifiedDistribution} -import org.apache.spark.util.MutablePair +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.util.collection.ExternalSorter /** @@ -36,11 +33,11 @@ import org.apache.spark.util.collection.ExternalSorter */ @DeveloperApi case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { - override def output = projectList.map(_.toAttribute) + override def output: Seq[Attribute] = projectList.map(_.toAttribute) @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - def execute() = child.execute().mapPartitions { iter => + override def execute(): RDD[Row] = child.execute().mapPartitions { iter => val resuableProjection = buildProjection() iter.map(resuableProjection) } @@ -51,11 +48,11 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends */ @DeveloperApi case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output - @transient lazy val conditionEvaluator = newPredicate(condition, child.output) + @transient lazy val conditionEvaluator: (Row) => Boolean = newPredicate(condition, child.output) - def execute() = child.execute().mapPartitions { iter => + override def execute(): RDD[Row] = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } } @@ -67,10 +64,12 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output // TODO: How to pick seed? - override def execute() = child.execute().map(_.copy()).sample(withReplacement, fraction, seed) + override def execute(): RDD[Row] = { + child.execute().map(_.copy()).sample(withReplacement, fraction, seed) + } } /** @@ -79,8 +78,8 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: @DeveloperApi case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes - override def output = children.head.output - override def execute() = sparkContext.union(children.map(_.execute())) + override def output: Seq[Attribute] = children.head.output + override def execute(): RDD[Row] = sparkContext.union(children.map(_.execute())) } /** @@ -100,54 +99,12 @@ case class Limit(limit: Int, child: SparkPlan) /** We must copy rows when sort based shuffle is on */ private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - override def output = child.output - override def outputPartitioning = SinglePartition - - /** - * A custom implementation modeled after the take function on RDDs but which never runs any job - * locally. This is to avoid shipping an entire partition of data in order to retrieve only a few - * rows. - */ - override def executeCollect(): Array[Row] = { - if (limit == 0) { - return new Array[Row](0) - } - - val childRDD = child.execute().map(_.copy()) - - val buf = new ArrayBuffer[Row] - val totalParts = childRDD.partitions.length - var partsScanned = 0 - while (buf.size < limit && partsScanned < totalParts) { - // The number of partitions to try in this iteration. It is ok for this number to be - // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 - if (partsScanned > 0) { - // If we didn't find any rows after the first iteration, just try all partitions next. - // Otherwise, interpolate the number of partitions we need to try, but overestimate it - // by 50%. - if (buf.size == 0) { - numPartsToTry = totalParts - 1 - } else { - numPartsToTry = (1.5 * limit * partsScanned / buf.size).toInt - } - } - numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions - - val left = limit - buf.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) - val sc = sqlContext.sparkContext - val res = - sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false) + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = SinglePartition - res.foreach(buf ++= _.take(limit - buf.size)) - partsScanned += numPartsToTry - } - - buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) - } + override def executeCollect(): Array[Row] = child.executeTake(limit) - override def execute() = { + override def execute(): RDD[Row] = { val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) { child.execute().mapPartitions { iter => iter.take(limit).map(row => (false, row.copy())) @@ -174,18 +131,21 @@ case class Limit(limit: Int, child: SparkPlan) @DeveloperApi case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { - override def output = child.output - override def outputPartitioning = SinglePartition + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = SinglePartition - val ord = new RowOrdering(sortOrder, child.output) + private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) + + private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord) // TODO: Is this copying for no reason? - override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord) - .map(ScalaReflection.convertRowToScala(_, this.schema)) + override def executeCollect(): Array[Row] = + collectData().map(ScalaReflection.convertRowToScala(_, this.schema)) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sparkContext.makeRDD(executeCollect(), 1) + override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) } /** @@ -200,17 +160,17 @@ case class Sort( global: Boolean, child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - override def execute() = attachTree(this, "sort") { + override def execute(): RDD[Row] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) iterator.map(_.copy()).toArray.sorted(ordering).iterator }, preservesPartitioning = true) } - override def output = child.output + override def output: Seq[Attribute] = child.output } /** @@ -225,19 +185,22 @@ case class ExternalSort( global: Boolean, child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + + override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - override def execute() = attachTree(this, "sort") { + override def execute(): RDD[Row] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering)) sorter.insertAll(iterator.map(r => (r, null))) - sorter.iterator.map(_._1) + val baseIterator = sorter.iterator.map(_._1) + // TODO(marmbrus): The complex type signature below thwarts inference for no reason. + CompletionIterator[Row, Iterator[Row]](baseIterator, sorter.stop()) }, preservesPartitioning = true) } - override def output = child.output + override def output: Seq[Attribute] = child.output } /** @@ -249,12 +212,12 @@ case class ExternalSort( */ @DeveloperApi case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil - override def execute() = { + override def execute(): RDD[Row] = { child.execute().mapPartitions { iter => val hashSet = new scala.collection.mutable.HashSet[Row]() @@ -279,9 +242,9 @@ case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { */ @DeveloperApi case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output = left.output + override def output: Seq[Attribute] = left.output - override def execute() = { + override def execute(): RDD[Row] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } } @@ -293,9 +256,9 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { */ @DeveloperApi case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output = children.head.output + override def output: Seq[Attribute] = children.head.output - override def execute() = { + override def execute(): RDD[Row] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } } @@ -308,6 +271,7 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { */ @DeveloperApi case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { - def children = child :: Nil - def execute() = child.execute() + def children: Seq[SparkPlan] = child :: Nil + + def execute(): RDD[Row] = child.execute() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index e1c9a2be7d20d..fad7a281dc1e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{BooleanType, StructField, StructType, StringType} import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -52,12 +53,14 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { */ protected[sql] lazy val sideEffectResult: Seq[Row] = cmd.run(sqlContext) - override def output = cmd.output + override def output: Seq[Attribute] = cmd.output - override def children = Nil + override def children: Seq[SparkPlan] = Nil override def executeCollect(): Array[Row] = sideEffectResult.toArray + override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray + override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) } @@ -67,9 +70,10 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { @DeveloperApi case class SetCommand( kv: Option[(String, Option[String])], - override val output: Seq[Attribute]) extends RunnableCommand with Logging { + override val output: Seq[Attribute]) + extends RunnableCommand with Logging { - override def run(sqlContext: SQLContext) = kv match { + override def run(sqlContext: SQLContext): Seq[Row] = kv match { // Configures the deprecated "mapred.reduce.tasks" property. case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => logWarning( @@ -113,10 +117,13 @@ case class SetCommand( @DeveloperApi case class ExplainCommand( logicalPlan: LogicalPlan, - override val output: Seq[Attribute], extended: Boolean = false) extends RunnableCommand { + override val output: Seq[Attribute] = + Seq(AttributeReference("plan", StringType, nullable = false)()), + extended: Boolean = false) + extends RunnableCommand { // Run through the optimizer to generate the physical plan. - override def run(sqlContext: SQLContext) = try { + override def run(sqlContext: SQLContext): Seq[Row] = try { // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. val queryExecution = sqlContext.executePlan(logicalPlan) val outputString = if (extended) queryExecution.toString else queryExecution.simpleString @@ -134,11 +141,12 @@ case class ExplainCommand( case class CacheTableCommand( tableName: String, plan: Option[LogicalPlan], - isLazy: Boolean) extends RunnableCommand { + isLazy: Boolean) + extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { plan.foreach { logicalPlan => - sqlContext.registerRDDAsTable(DataFrame(sqlContext, logicalPlan), tableName) + sqlContext.registerDataFrameAsTable(DataFrame(sqlContext, logicalPlan), tableName) } sqlContext.cacheTable(tableName) @@ -160,7 +168,7 @@ case class CacheTableCommand( @DeveloperApi case class UncacheTableCommand(tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.table(tableName).unpersist(blocking = false) Seq.empty[Row] } @@ -168,15 +176,68 @@ case class UncacheTableCommand(tableName: String) extends RunnableCommand { override def output: Seq[Attribute] = Seq.empty } +/** + * :: DeveloperApi :: + * Clear all cached data from the in-memory cache. + */ +@DeveloperApi +case object ClearCacheCommand extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.clearCache() + Seq.empty[Row] + } + + override def output: Seq[Attribute] = Seq.empty +} + /** * :: DeveloperApi :: */ @DeveloperApi case class DescribeCommand( child: SparkPlan, - override val output: Seq[Attribute]) extends RunnableCommand { + override val output: Seq[Attribute], + isExtended: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + child.schema.fields.map { field => + val cmtKey = "comment" + val comment = if (field.metadata.contains(cmtKey)) field.metadata.getString(cmtKey) else "" + Row(field.name, field.dataType.simpleString, comment) + } + } +} + +/** + * A command for users to get tables in the given database. + * If a databaseName is not given, the current database will be used. + * The syntax of using this command in SQL is: + * {{{ + * SHOW TABLES [IN databaseName] + * }}} + * :: DeveloperApi :: + */ +@DeveloperApi +case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand { + + // The result of SHOW TABLES has two columns, tableName and isTemporary. + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("tableName", StringType, false) :: + StructField("isTemporary", BooleanType, false) :: Nil) + + schema.toAttributes + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we need to return a Seq of rows, we will call getTables directly + // instead of calling tables in sqlContext. + val rows = sqlContext.catalog.getTables(databaseName).map { + case (tableName, isTemporary) => Row(tableName, isTemporary) + } - override def run(sqlContext: SQLContext) = { - child.output.map(field => Row(field.name, field.dataType.toString, null)) + rows } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 5cc67cdd13944..e916e68e58b5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.execution +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Attribute + import scala.collection.mutable.HashSet -import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext} +import org.apache.spark.{AccumulatorParam, Accumulator} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.SparkContext._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{SQLConf, SQLContext, DataFrame, Row} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ @@ -32,11 +34,22 @@ import org.apache.spark.sql.types._ * * Usage: * {{{ - * sql("SELECT key FROM src").debug + * import org.apache.spark.sql.execution.debug._ + * sql("SELECT key FROM src").debug() + * dataFrame.typeCheck() * }}} */ package object debug { + /** + * Augments [[SQLContext]] with debug methods. + */ + implicit class DebugSQLContext(sqlContext: SQLContext) { + def debug(): Unit = { + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + } + } + /** * :: DeveloperApi :: * Augments [[DataFrame]]s with debug methods. @@ -77,7 +90,7 @@ package object debug { } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { - def output = child.output + def output: Seq[Attribute] = child.output implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { def zero(initialValue: HashSet[String]): HashSet[String] = { @@ -98,10 +111,10 @@ package object debug { */ case class ColumnMetrics( elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) - val tupleCount = sparkContext.accumulator[Int](0) + val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0) - val numColumns = child.output.size - val columnStats = Array.fill(child.output.size)(new ColumnMetrics()) + val numColumns: Int = child.output.size + val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { println(s"== ${child.simpleString} ==") @@ -112,11 +125,11 @@ package object debug { } } - def execute() = { + def execute(): RDD[Row] = { child.execute().mapPartitions { iter => new Iterator[Row] { - def hasNext = iter.hasNext - def next() = { + def hasNext: Boolean = iter.hasNext + def next(): Row = { val currentRow = iter.next() tupleCount += 1 var i = 0 @@ -135,11 +148,9 @@ package object debug { } /** - * :: DeveloperApi :: * Helper functions for checking that runtime types match a given schema. */ - @DeveloperApi - object TypeCheck { + private[sql] object TypeCheck { def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { case (null, _) => @@ -159,31 +170,30 @@ package object debug { case (_: Short, ShortType) => case (_: Boolean, BooleanType) => case (_: Double, DoubleType) => + case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType) case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") } } /** - * :: DeveloperApi :: * Augments [[DataFrame]]s with debug methods. */ - @DeveloperApi private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan { import TypeCheck._ - override def nodeName = "" + override def nodeName: String = "" /* Only required when defining this class in a REPL. override def makeCopy(args: Array[Object]): this.type = TypeCheck(args(0).asInstanceOf[SparkPlan]).asInstanceOf[this.type] */ - def output = child.output + def output: Seq[Attribute] = child.output - def children = child :: Nil + def children: List[SparkPlan] = child :: Nil - def execute() = { + def execute(): RDD[Row] = { child.execute().map { row => try typeCheck(row, child.schema) catch { case e: Exception => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 2dd22c020ef12..926f5e6c137ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.rdd.RDD + import scala.concurrent._ import scala.concurrent.duration._ import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Row, Expression} -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -42,7 +44,7 @@ case class BroadcastHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - val timeout = { + val timeout: Duration = { val timeoutValue = sqlContext.conf.broadcastTimeout if (timeoutValue < 0) { Duration.Inf @@ -53,7 +55,7 @@ case class BroadcastHashJoin( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil @transient @@ -64,7 +66,7 @@ case class BroadcastHashJoin( sparkContext.broadcast(hashed) } - override def execute() = { + override def execute(): RDD[Row] = { val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 2ab064fd0151e..3ef1e0d7fbdd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Expression, Row} -import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -34,11 +34,11 @@ case class BroadcastLeftSemiJoinHash( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashJoin { - override val buildSide = BuildRight + override val buildSide: BuildSide = BuildRight - override def output = left.output + override def output: Seq[Attribute] = left.output - override def execute() = { + override def execute(): RDD[Row] = { val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 36aad13778bd2..83b1a83765153 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} @@ -44,7 +45,7 @@ case class BroadcastNestedLoopJoin( override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def output = { + override def output: Seq[Attribute] = { joinType match { case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -63,7 +64,7 @@ case class BroadcastNestedLoopJoin( .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) - override def execute() = { + override def execute(): RDD[Row] = { val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 76c14c02aab34..1cbc98354d673 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -26,9 +28,9 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} */ @DeveloperApi case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output = left.output ++ right.output + override def output: Seq[Attribute] = left.output ++ right.output - override def execute() = { + override def execute(): RDD[Row] = { val leftResults = left.execute().map(_.copy()) val rightResults = right.execute().map(_.copy()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 4012d757d5f9a..851de1685509a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -41,7 +41,7 @@ trait HashJoin { case BuildRight => (rightKeys, leftKeys) } - override def output = left.output ++ right.output + override def output: Seq[Attribute] = left.output ++ right.output @transient protected lazy val buildSideKeyGenerator: Projection = newProjection(buildKeys, buildPlan.output) @@ -65,7 +65,7 @@ trait HashJoin { (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || (streamIter.hasNext && fetchNext()) - override final def next() = { + override final def next(): Row = { val ret = buildSide match { case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 59ef904272545..a396c0f5d56ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.joins import java.util.{HashMap => JavaHashMap} +import org.apache.spark.rdd.RDD + import scala.collection.JavaConversions._ import org.apache.spark.annotation.DeveloperApi @@ -49,10 +51,10 @@ case class HashOuterJoin( case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output = { + override def output: Seq[Attribute] = { joinType match { case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -78,12 +80,12 @@ case class HashOuterJoin( private[this] def leftOuterIterator( key: Row, joinedRow: JoinedRow, rightIter: Iterable[Row]): Iterator[Row] = { - val ret: Iterable[Row] = ( + val ret: Iterable[Row] = { if (!key.anyNull) { val temp = rightIter.collect { - case r if (boundCondition(joinedRow.withRight(r))) => joinedRow.copy + case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() } - if (temp.size == 0) { + if (temp.size == 0) { joinedRow.withRight(rightNullRow).copy :: Nil } else { temp @@ -91,19 +93,19 @@ case class HashOuterJoin( } else { joinedRow.withRight(rightNullRow).copy :: Nil } - ) + } ret.iterator } private[this] def rightOuterIterator( key: Row, leftIter: Iterable[Row], joinedRow: JoinedRow): Iterator[Row] = { - val ret: Iterable[Row] = ( + val ret: Iterable[Row] = { if (!key.anyNull) { val temp = leftIter.collect { - case l if (boundCondition(joinedRow.withLeft(l))) => joinedRow.copy + case l if boundCondition(joinedRow.withLeft(l)) => joinedRow.copy } - if (temp.size == 0) { + if (temp.size == 0) { joinedRow.withLeft(leftNullRow).copy :: Nil } else { temp @@ -111,7 +113,7 @@ case class HashOuterJoin( } else { joinedRow.withLeft(leftNullRow).copy :: Nil } - ) + } ret.iterator } @@ -130,12 +132,12 @@ case class HashOuterJoin( // 1. For those matched (satisfy the join condition) records with both sides filled, // append them directly - case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + case (r, idx) if boundCondition(joinedRow.withRight(r)) => matched = true // if the row satisfy the join condition, add its index into the matched set rightMatchedSet.add(idx) - joinedRow.copy - } + joinedRow.copy() + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // 2. For those unmatched records in left, append additional records with empty right. @@ -143,22 +145,21 @@ case class HashOuterJoin( // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. - joinedRow.withRight(rightNullRow).copy + joinedRow.withRight(rightNullRow).copy() }) } ++ rightIter.zipWithIndex.collect { // 3. For those unmatched records in right, append additional records with empty left. // Re-visiting the records in right, and append additional row with empty left, if its not // in the matched set. - case (r, idx) if (!rightMatchedSet.contains(idx)) => { - joinedRow(leftNullRow, r).copy - } + case (r, idx) if !rightMatchedSet.contains(idx) => + joinedRow(leftNullRow, r).copy() } } else { leftIter.iterator.map[Row] { l => - joinedRow(l, rightNullRow).copy + joinedRow(l, rightNullRow).copy() } ++ rightIter.iterator.map[Row] { r => - joinedRow(leftNullRow, r).copy + joinedRow(leftNullRow, r).copy() } } } @@ -182,13 +183,13 @@ case class HashOuterJoin( hashTable } - override def execute() = { + override def execute(): RDD[Row] = { val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // TODO this probably can be replaced by external sort (sort merged join?) joinType match { - case LeftOuter => { + case LeftOuter => val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) val keyGenerator = newProjection(leftKeys, left.output) leftIter.flatMap( currentRow => { @@ -196,8 +197,8 @@ case class HashOuterJoin( joinedRow.withLeft(currentRow) leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) }) - } - case RightOuter => { + + case RightOuter => val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val keyGenerator = newProjection(rightKeys, right.output) rightIter.flatMap ( currentRow => { @@ -205,8 +206,8 @@ case class HashOuterJoin( joinedRow.withRight(currentRow) rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) }) - } - case FullOuter => { + + case FullOuter => val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => @@ -214,7 +215,7 @@ case class HashOuterJoin( leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) } - } + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 38b8993b03f82..2fa1cf5add3b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -38,7 +38,7 @@ private[joins] sealed trait HashedRelation { private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]]) extends HashedRelation with Serializable { - override def get(key: Row) = hashTable.get(key) + override def get(key: Row): CompactBuffer[Row] = hashTable.get(key) } @@ -49,7 +49,7 @@ private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, Com private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) extends HashedRelation with Serializable { - override def get(key: Row) = { + override def get(key: Row): CompactBuffer[Row] = { val v = hashTable.get(key) if (v eq null) null else CompactBuffer(v) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 60003d1900d85..1fa7e7bd0406c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -35,12 +36,13 @@ case class LeftSemiJoinBNL( override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def output = left.output + override def output: Seq[Attribute] = left.output /** The Streamed Relation */ - override def left = streamed + override def left: SparkPlan = streamed + /** The Broadcast relation */ - override def right = broadcast + override def right: SparkPlan = broadcast @transient private lazy val boundCondition = InterpretedPredicate( @@ -48,7 +50,7 @@ case class LeftSemiJoinBNL( .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) - override def execute() = { + override def execute(): RDD[Row] = { val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index ea7babf3be948..a04f2a63b5a55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Expression, Row} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -34,14 +35,14 @@ case class LeftSemiJoinHash( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashJoin { - override val buildSide = BuildRight + override val buildSide: BuildSide = BuildRight - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output = left.output + override def output: Seq[Attribute] = left.output - override def execute() = { + override def execute(): RDD[Row] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 418c1c23e5546..a6cd8337c1c3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -38,10 +40,10 @@ case class ShuffledHashJoin( override def outputPartitioning: Partitioning = left.outputPartitioning - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def execute() = { + override def execute(): RDD[Row] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => val hashed = HashedRelation(buildIter, buildSideKeyGenerator) hashJoin(streamIter, hashed) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index b85021acc9d4c..5b308d88d4cdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.{List => JList, Map => JMap} +import org.apache.spark.rdd.RDD + import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -48,11 +50,13 @@ private[spark] case class PythonUDF( dataType: DataType, children: Seq[Expression]) extends Expression with SparkLogging { - override def toString = s"PythonUDF#$name(${children.mkString(",")})" + override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" def nullable: Boolean = true - override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.") + override def eval(input: Row): PythonUDF.this.EvaluatedType = { + sys.error("PythonUDFs can not be directly evaluated.") + } } /** @@ -63,7 +67,7 @@ private[spark] case class PythonUDF( * multiple child operators. */ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan) = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Skip EvaluatePython nodes. case p: EvaluatePython => p @@ -107,7 +111,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { } object EvaluatePython { - def apply(udf: PythonUDF, child: LogicalPlan) = + def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) /** @@ -135,6 +139,8 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + case (date: Int, DateType) => DateUtils.toJavaDate(date) + // Pyrolite can handle Timestamp and Decimal case (other, _) => other } @@ -171,7 +177,7 @@ object EvaluatePython { }): Row case (c: java.util.Calendar, DateType) => - new java.sql.Date(c.getTime().getTime()) + DateUtils.fromJavaDate(new java.sql.Date(c.getTime().getTime())) case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) @@ -184,6 +190,7 @@ object EvaluatePython { case (c: Int, ShortType) => c.toShort case (c: Long, ShortType) => c.toShort case (c: Long, IntegerType) => c.toInt + case (c: Int, LongType) => c.toLong case (c: Double, FloatType) => c.toFloat case (c, StringType) if !c.isInstanceOf[String] => c.toString @@ -202,7 +209,10 @@ case class EvaluatePython( resultAttribute: AttributeReference) extends logical.UnaryNode { - def output = child.output :+ resultAttribute + def output: Seq[Attribute] = child.output :+ resultAttribute + + // References should not include the produced attribute. + override def references: AttributeSet = udf.references } /** @@ -213,9 +223,10 @@ case class EvaluatePython( @DeveloperApi case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) extends SparkPlan { - def children = child :: Nil - def execute() = { + def children: Seq[SparkPlan] = child :: Nil + + def execute(): RDD[Row] = { // TODO: Clean up after ourselves? val childResults = child.execute().map(_.copy()).cache() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala new file mode 100644 index 0000000000000..111e751588a8b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -0,0 +1,608 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + + +/** + * :: Experimental :: + * Functions available for [[DataFrame]]. + * + * @groupname udf_funcs UDF functions + * @groupname agg_funcs Aggregate functions + * @groupname sort_funcs Sorting functions + * @groupname normal_funcs Non-aggregate functions + * @groupname Ungrouped Support functions for DataFrames. + */ +@Experimental +// scalastyle:off +object functions { +// scalastyle:on + + private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + + /** + * Returns a [[Column]] based on the given column name. + * + * @group normal_funcs + */ + def col(colName: String): Column = Column(colName) + + /** + * Returns a [[Column]] based on the given column name. Alias of [[col]]. + * + * @group normal_funcs + */ + def column(colName: String): Column = Column(colName) + + /** + * Creates a [[Column]] of literal value. + * + * The passed in object is returned directly if it is already a [[Column]]. + * If the object is a Scala Symbol, it is converted into a [[Column]] also. + * Otherwise, a new [[Column]] is created to represent the literal value. + * + * @group normal_funcs + */ + def lit(literal: Any): Column = { + literal match { + case c: Column => return c + case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name) + case _ => // continue + } + + val literalExpr = Literal(literal) + Column(literalExpr) + } + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Sort functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns a sort expression based on ascending order of the column. + * {{ + * // Sort by dept in ascending order, and then age in descending order. + * df.sort(asc("dept"), desc("age")) + * }} + * + * @group sort_funcs + */ + def asc(columnName: String): Column = Column(columnName).asc + + /** + * Returns a sort expression based on the descending order of the column. + * {{ + * // Sort by dept in ascending order, and then age in descending order. + * df.sort(asc("dept"), desc("age")) + * }} + * + * @group sort_funcs + */ + def desc(columnName: String): Column = Column(columnName).desc + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Aggregate functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Aggregate function: returns the sum of all values in the expression. + * + * @group agg_funcs + */ + def sum(e: Column): Column = Sum(e.expr) + + /** + * Aggregate function: returns the sum of all values in the given column. + * + * @group agg_funcs + */ + def sum(columnName: String): Column = sum(Column(columnName)) + + /** + * Aggregate function: returns the sum of distinct values in the expression. + * + * @group agg_funcs + */ + def sumDistinct(e: Column): Column = SumDistinct(e.expr) + + /** + * Aggregate function: returns the sum of distinct values in the expression. + * + * @group agg_funcs + */ + def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) + + /** + * Aggregate function: returns the number of items in a group. + * + * @group agg_funcs + */ + def count(e: Column): Column = e.expr match { + // Turn count(*) into count(1) + case s: Star => Count(Literal(1)) + case _ => Count(e.expr) + } + + /** + * Aggregate function: returns the number of items in a group. + * + * @group agg_funcs + */ + def count(columnName: String): Column = count(Column(columnName)) + + /** + * Aggregate function: returns the number of distinct items in a group. + * + * @group agg_funcs + */ + @scala.annotation.varargs + def countDistinct(expr: Column, exprs: Column*): Column = + CountDistinct((expr +: exprs).map(_.expr)) + + /** + * Aggregate function: returns the number of distinct items in a group. + * + * @group agg_funcs + */ + @scala.annotation.varargs + def countDistinct(columnName: String, columnNames: String*): Column = + countDistinct(Column(columnName), columnNames.map(Column.apply) :_*) + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + */ + def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + */ + def approxCountDistinct(columnName: String): Column = approxCountDistinct(column(columnName)) + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + */ + def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + */ + def approxCountDistinct(columnName: String, rsd: Double): Column = { + approxCountDistinct(Column(columnName), rsd) + } + + /** + * Aggregate function: returns the average of the values in a group. + * + * @group agg_funcs + */ + def avg(e: Column): Column = Average(e.expr) + + /** + * Aggregate function: returns the average of the values in a group. + * + * @group agg_funcs + */ + def avg(columnName: String): Column = avg(Column(columnName)) + + /** + * Aggregate function: returns the first value in a group. + * + * @group agg_funcs + */ + def first(e: Column): Column = First(e.expr) + + /** + * Aggregate function: returns the first value of a column in a group. + * + * @group agg_funcs + */ + def first(columnName: String): Column = first(Column(columnName)) + + /** + * Aggregate function: returns the last value in a group. + * + * @group agg_funcs + */ + def last(e: Column): Column = Last(e.expr) + + /** + * Aggregate function: returns the last value of the column in a group. + * + * @group agg_funcs + */ + def last(columnName: String): Column = last(Column(columnName)) + + /** + * Aggregate function: returns the minimum value of the expression in a group. + * + * @group agg_funcs + */ + def min(e: Column): Column = Min(e.expr) + + /** + * Aggregate function: returns the minimum value of the column in a group. + * + * @group agg_funcs + */ + def min(columnName: String): Column = min(Column(columnName)) + + /** + * Aggregate function: returns the maximum value of the expression in a group. + * + * @group agg_funcs + */ + def max(e: Column): Column = Max(e.expr) + + /** + * Aggregate function: returns the maximum value of the column in a group. + * + * @group agg_funcs + */ + def max(columnName: String): Column = max(Column(columnName)) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Non-aggregate functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns the first column that is not null. + * {{{ + * df.select(coalesce(df("a"), df("b"))) + * }}} + * + * @group normal_funcs + */ + @scala.annotation.varargs + def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + + /** + * Unary minus, i.e. negate the expression. + * {{{ + * // Select the amount column and negates all values. + * // Scala: + * df.select( -df("amount") ) + * + * // Java: + * df.select( negate(df.col("amount")) ); + * }}} + * + * @group normal_funcs + */ + def negate(e: Column): Column = -e + + /** + * Inversion of boolean expression, i.e. NOT. + * {{ + * // Scala: select rows that are not active (isActive === false) + * df.filter( !df("isActive") ) + * + * // Java: + * df.filter( not(df.col("isActive")) ); + * }} + * + * @group normal_funcs + */ + def not(e: Column): Column = !e + + /** + * Converts a string expression to upper case. + * + * @group normal_funcs + */ + def upper(e: Column): Column = Upper(e.expr) + + /** + * Converts a string exprsesion to lower case. + * + * @group normal_funcs + */ + def lower(e: Column): Column = Lower(e.expr) + + /** + * Computes the square root of the specified float value. + * + * @group normal_funcs + */ + def sqrt(e: Column): Column = Sqrt(e.expr) + + /** + * Computes the absolutle value. + * + * @group normal_funcs + */ + def abs(e: Column): Column = Abs(e.expr) + + ////////////////////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////////////////////// + + // scalastyle:off + + /* Use the following code to generate: + (0 to 10).map { x => + val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + println(s""" + /** + * Defines a user-defined function of ${x} arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + }""") + } + + (0 to 10).map { x => + val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") + val fTypes = Seq.fill(x + 1)("_").mkString(", ") + val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + println(s""" + /** + * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { + ScalaUdf(f, returnType, Seq($argsInUdf)) + }""") + } + } + */ + /** + * Defines a user-defined function of 0 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 6 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 7 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 8 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 9 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 10 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Call a Scala function of 0 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function0[_], returnType: DataType): Column = { + ScalaUdf(f, returnType, Seq()) + } + + /** + * Call a Scala function of 1 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr)) + } + + /** + * Call a Scala function of 2 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + } + + /** + * Call a Scala function of 3 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + } + + /** + * Call a Scala function of 4 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + } + + /** + * Call a Scala function of 5 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + } + + /** + * Call a Scala function of 6 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + } + + /** + * Call a Scala function of 7 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + } + + /** + * Call a Scala function of 8 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + } + + /** + * Call a Scala function of 9 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + } + + /** + * Call a Scala function of 10 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + } + + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index a2f94675fb5a3..463e1dcc268bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.jdbc -import java.sql.{Connection, DatabaseMetaData, DriverManager, ResultSet, ResultSetMetaData, SQLException} -import scala.collection.mutable.ArrayBuffer +import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} +import java.util.Properties +import org.apache.commons.lang.StringEscapeUtils.escapeSql import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.util.NextIterator -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} import org.apache.spark.sql.types._ import org.apache.spark.sql.sources._ @@ -92,15 +91,15 @@ private[sql] object JDBCRDD extends Logging { * @throws SQLException if the table specification is garbage. * @throws SQLException if the table contains an unsupported type. */ - def resolveTable(url: String, table: String): StructType = { + def resolveTable(url: String, table: String, properties: Properties): StructType = { val quirks = DriverQuirks.get(url) - val conn: Connection = DriverManager.getConnection(url) + val conn: Connection = DriverManager.getConnection(url, properties) try { val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() try { val rsmd = rs.getMetaData val ncols = rsmd.getColumnCount - var fields = new Array[StructField](ncols); + val fields = new Array[StructField](ncols) var i = 0 while (i < ncols) { val columnName = rsmd.getColumnName(i + 1) @@ -149,7 +148,7 @@ private[sql] object JDBCRDD extends Logging { * * @return A function that loads the driver and connects to the url. */ - def getConnector(driver: String, url: String): () => Connection = { + def getConnector(driver: String, url: String, properties: Properties): () => Connection = { () => { try { if (driver != null) Class.forName(driver) @@ -158,7 +157,7 @@ private[sql] object JDBCRDD extends Logging { logWarning(s"Couldn't find class $driver", e); } } - DriverManager.getConnection(url) + DriverManager.getConnection(url, properties) } } /** @@ -176,23 +175,28 @@ private[sql] object JDBCRDD extends Logging { * * @return An RDD representing "SELECT requiredColumns FROM fqTable". */ - def scanTable(sc: SparkContext, - schema: StructType, - driver: String, - url: String, - fqTable: String, - requiredColumns: Array[String], - filters: Array[Filter], - parts: Array[Partition]): RDD[Row] = { + def scanTable( + sc: SparkContext, + schema: StructType, + driver: String, + url: String, + properties: Properties, + fqTable: String, + requiredColumns: Array[String], + filters: Array[Filter], + parts: Array[Partition]): RDD[Row] = { + val prunedSchema = pruneSchema(schema, requiredColumns) - return new JDBCRDD(sc, - getConnector(driver, url), - prunedSchema, - fqTable, - requiredColumns, - filters, - parts) + return new + JDBCRDD( + sc, + getConnector(driver, url, properties), + prunedSchema, + fqTable, + requiredColumns, + filters, + parts) } } @@ -225,16 +229,24 @@ private[sql] class JDBCRDD( if (sb.length == 0) "1" else sb.substring(1) } + /** + * Converts value to SQL expression. + */ + private def compileValue(value: Any): Any = value match { + case stringValue: String => s"'${escapeSql(stringValue)}'" + case _ => value + } + /** * Turns a single Filter into a String representing a SQL expression. * Returns null for an unhandled filter. */ private def compileFilter(f: Filter): String = f match { - case EqualTo(attr, value) => s"$attr = $value" - case LessThan(attr, value) => s"$attr < $value" - case GreaterThan(attr, value) => s"$attr > $value" - case LessThanOrEqual(attr, value) => s"$attr <= $value" - case GreaterThanOrEqual(attr, value) => s"$attr >= $value" + case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case LessThan(attr, value) => s"$attr < ${compileValue(value)}" + case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" + case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" case _ => null } @@ -305,7 +317,8 @@ private[sql] class JDBCRDD( /** * Runs the SQL query against the JDBC driver. */ - override def compute(thePart: Partition, context: TaskContext) = new Iterator[Row] { + override def compute(thePart: Partition, context: TaskContext): Iterator[Row] = new Iterator[Row] + { var closed = false var finished = false var gotNext = false @@ -350,7 +363,7 @@ private[sql] class JDBCRDD( var ans = 0L var j = 0 while (j < bytes.size) { - ans = 256*ans + (255 & bytes(j)) + ans = 256 * ans + (255 & bytes(j)) j = j + 1; } mutableRow.setLong(i, ans) @@ -369,21 +382,21 @@ private[sql] class JDBCRDD( def close() { if (closed) return try { - if (null != rs && ! rs.isClosed()) { + if (null != rs) { rs.close() } } catch { case e: Exception => logWarning("Exception closing resultset", e) } try { - if (null != stmt && ! stmt.isClosed()) { + if (null != stmt) { stmt.close() } } catch { case e: Exception => logWarning("Exception closing statement", e) } try { - if (null != conn && ! conn.isClosed()) { + if (null != conn) { conn.close() } logInfo("closed connection") @@ -412,6 +425,5 @@ private[sql] class JDBCRDD( gotNext = false nextValue } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index e09125e406ba2..4fa84dc076f7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -17,12 +17,17 @@ package org.apache.spark.sql.jdbc -import scala.collection.mutable.ArrayBuffer import java.sql.DriverManager +import java.util.Properties + +import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition +import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StructType /** * Data corresponding to one partition of a JDBCRDD. @@ -48,11 +53,6 @@ private[sql] object JDBCRelation { * exactly once. The parameters minValue and maxValue are advisory in that * incorrect values may cause the partitioning to be poor, but no data * will fail to be represented. - * - * @param column - Column name. Must refer to a column of integral type. - * @param numPartitions - Number of partitions - * @param minValue - Smallest value of column. Advisory. - * @param maxValue - Largest value of column. Advisory. */ def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { if (partitioning == null) return Array[Partition](JDBCPartition(null, 0)) @@ -68,12 +68,17 @@ private[sql] object JDBCRelation { var currentValue: Long = partitioning.lowerBound var ans = new ArrayBuffer[Partition]() while (i < numPartitions) { - val lowerBound = (if (i != 0) s"$column >= $currentValue" else null) + val lowerBound = if (i != 0) s"$column >= $currentValue" else null currentValue += stride - val upperBound = (if (i != numPartitions - 1) s"$column < $currentValue" else null) - val whereClause = (if (upperBound == null) lowerBound - else if (lowerBound == null) upperBound - else s"$lowerBound AND $upperBound") + val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null + val whereClause = + if (upperBound == null) { + lowerBound + } else if (lowerBound == null) { + upperBound + } else { + s"$lowerBound AND $upperBound" + } ans += JDBCPartition(whereClause, i) i = i + 1 } @@ -96,7 +101,7 @@ private[sql] class DefaultSource extends RelationProvider { if (driver != null) Class.forName(driver) - if ( partitionColumn != null + if (partitionColumn != null && (lowerBound == null || upperBound == null || numPartitions == null)) { sys.error("Partitioning incompletely specified") } @@ -104,30 +109,40 @@ private[sql] class DefaultSource extends RelationProvider { val partitionInfo = if (partitionColumn == null) { null } else { - JDBCPartitioningInfo(partitionColumn, - lowerBound.toLong, upperBound.toLong, - numPartitions.toInt) + JDBCPartitioningInfo( + partitionColumn, + lowerBound.toLong, + upperBound.toLong, + numPartitions.toInt) } val parts = JDBCRelation.columnPartition(partitionInfo) - JDBCRelation(url, table, parts)(sqlContext) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) } } -private[sql] case class JDBCRelation(url: String, - table: String, - parts: Array[Partition])( - @transient val sqlContext: SQLContext) - extends PrunedFilteredScan { +private[sql] case class JDBCRelation( + url: String, + table: String, + parts: Array[Partition], + properties: Properties = new Properties())(@transient val sqlContext: SQLContext) + extends BaseRelation + with PrunedFilteredScan { - override val schema = JDBCRDD.resolveTable(url, table) + override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName - JDBCRDD.scanTable(sqlContext.sparkContext, - schema, - driver, url, - table, - requiredColumns, filters, - parts) + JDBCRDD.scanTable( + sqlContext.sparkContext, + schema, + driver, + url, + properties, + table, + requiredColumns, + filters, + parts) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index 34a83f0a5dad8..34f864f5fda7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -26,11 +26,11 @@ import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartit import org.apache.spark.sql.types._ package object jdbc { - object JDBCWriteDetails extends Logging { + private[sql] object JDBCWriteDetails extends Logging { /** * Returns a PreparedStatement that inserts a row into table via conn. */ - private def insertStatement(conn: Connection, table: String, rddSchema: StructType): + def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { val sql = new StringBuilder(s"INSERT INTO $table VALUES (") var fieldsLeft = rddSchema.fields.length @@ -56,7 +56,7 @@ package object jdbc { * non-Serializable. Instead, we explicitly close over all variables that * are used. */ - private[jdbc] def savePartition(url: String, table: String, iterator: Iterator[Row], + def savePartition(url: String, table: String, iterator: Iterator[Row], rddSchema: StructType, nullTypes: Array[Int]): Iterator[Byte] = { val conn = DriverManager.getConnection(url) var committed = false @@ -117,19 +117,14 @@ package object jdbc { } Array[Byte]().iterator } - } - /** - * Make it so that you can call createJDBCTable and insertIntoJDBC on a DataFrame. - */ - implicit class JDBCDataFrame(rdd: DataFrame) { /** * Compute the schema string for this RDD. */ - private def schemaString(url: String): String = { + def schemaString(df: DataFrame, url: String): String = { val sb = new StringBuilder() val quirks = DriverQuirks.get(url) - rdd.schema.fields foreach { field => { + df.schema.fields foreach { field => { val name = field.name var typ: String = quirks.getJDBCType(field.dataType)._1 if (typ == null) typ = field.dataType match { @@ -156,9 +151,9 @@ package object jdbc { /** * Saves the RDD to the database in a single transaction. */ - private def saveTable(url: String, table: String) { + def saveTable(df: DataFrame, url: String, table: String) { val quirks = DriverQuirks.get(url) - var nullTypes: Array[Int] = rdd.schema.fields.map(field => { + var nullTypes: Array[Int] = df.schema.fields.map(field => { var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2 if (nullType.isEmpty) { field.dataType match { @@ -175,61 +170,16 @@ package object jdbc { case DateType => java.sql.Types.DATE case DecimalType.Unlimited => java.sql.Types.DECIMAL case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") + s"Can't translate null value for field $field") } } else nullType.get }).toArray - val rddSchema = rdd.schema - rdd.mapPartitions(iterator => JDBCWriteDetails.savePartition( - url, table, iterator, rddSchema, nullTypes)).collect() - } - - /** - * Save this RDD to a JDBC database at `url` under the table name `table`. - * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. - * If you pass `true` for `allowExisting`, it will drop any table with the - * given name; if you pass `false`, it will throw if the table already - * exists. - */ - def createJDBCTable(url: String, table: String, allowExisting: Boolean) { - val conn = DriverManager.getConnection(url) - try { - if (allowExisting) { - val sql = s"DROP TABLE IF EXISTS $table" - conn.prepareStatement(sql).executeUpdate() - } - val schema = schemaString(url) - val sql = s"CREATE TABLE $table ($schema)" - conn.prepareStatement(sql).executeUpdate() - } finally { - conn.close() + val rddSchema = df.schema + df.foreachPartition { iterator => + JDBCWriteDetails.savePartition(url, table, iterator, rddSchema, nullTypes) } - saveTable(url, table) } - /** - * Save this RDD to a JDBC database at `url` under the table name `table`. - * Assumes the table already exists and has a compatible schema. If you - * pass `true` for `overwrite`, it will `TRUNCATE` the table before - * performing the `INSERT`s. - * - * The table must already exist on the database. It must have a schema - * that is compatible with the schema of this RDD; inserting the rows of - * the RDD in order via the simple statement - * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. - */ - def insertIntoJDBC(url: String, table: String, overwrite: Boolean) { - if (overwrite) { - val conn = DriverManager.getConnection(url) - try { - val sql = s"TRUNCATE TABLE $table" - conn.prepareStatement(sql).executeUpdate() - } finally { - conn.close() - } - } - saveTable(url, table) - } - } // implicit class JDBCDataFrame + } } // package object jdbc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 8372decbf8aa1..f4c99b4b56606 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -21,19 +21,25 @@ import java.io.IOException import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.Row + +import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} private[sql] class DefaultSource - extends RelationProvider with SchemaRelationProvider with CreateableRelationProvider { + extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider { + + private def checkPath(parameters: Map[String, String]): String = { + parameters.getOrElse("path", sys.error("'path' must be specified for json data.")) + } /** Returns a new base relation with the parameters. */ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - val path = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val path = checkPath(parameters) val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) JSONRelation(path, samplingRatio, None)(sqlContext) @@ -44,7 +50,7 @@ private[sql] class DefaultSource sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = { - val path = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val path = checkPath(parameters) val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) JSONRelation(path, samplingRatio, Some(schema))(sqlContext) @@ -52,15 +58,44 @@ private[sql] class DefaultSource override def createRelation( sqlContext: SQLContext, + mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = { - val path = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val path = checkPath(parameters) val filesystemPath = new Path(path) val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - if (fs.exists(filesystemPath)) { - sys.error(s"path $path already exists.") + val doSave = if (fs.exists(filesystemPath)) { + mode match { + case SaveMode.Append => + sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") + case SaveMode.Overwrite => { + var success: Boolean = false + try { + success = fs.delete(filesystemPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${filesystemPath.toString} prior" + + s" to writing to JSON table:\n${e.toString}") + } + if (!success) { + throw new IOException( + s"Unable to clear output directory ${filesystemPath.toString} prior" + + s" to writing to JSON table.") + } + true + } + case SaveMode.ErrorIfExists => + sys.error(s"path $path already exists.") + case SaveMode.Ignore => false + } + } else { + true + } + if (doSave) { + // Only save data when the save mode is not ignore. + data.toJSON.saveAsTextFile(path) } - data.toJSON.saveAsTextFile(path) createRelation(sqlContext, parameters, data.schema) } @@ -71,7 +106,9 @@ private[sql] case class JSONRelation( samplingRatio: Double, userSpecifiedSchema: Option[StructType])( @transient val sqlContext: SQLContext) - extends TableScan with InsertableRelation { + extends BaseRelation + with TableScan + with InsertableRelation { // TODO: Support partitioned JSON relation. private def baseRDD = sqlContext.sparkContext.textFile(path) @@ -83,26 +120,45 @@ private[sql] case class JSONRelation( samplingRatio, sqlContext.conf.columnNameOfCorruptRecord))) - override def buildScan() = + override def buildScan(): RDD[Row] = JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord) - override def insert(data: DataFrame, overwrite: Boolean) = { + override def insert(data: DataFrame, overwrite: Boolean): Unit = { val filesystemPath = new Path(path) val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) if (overwrite) { - try { - fs.delete(filesystemPath, true) - } catch { - case e: IOException => + if (fs.exists(filesystemPath)) { + var success: Boolean = false + try { + success = fs.delete(filesystemPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${filesystemPath.toString} prior" + + s" to writing to JSON table:\n${e.toString}") + } + if (!success) { throw new IOException( s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to INSERT OVERWRITE a JSON table:\n${e.toString}") + + s" to writing to JSON table.") + } } + // Write the data. data.toJSON.saveAsTextFile(path) + // Right now, we assume that the schema is not changed. We will not update the schema. + // schema = data.schema } else { // TODO: Support INSERT INTO sys.error("JSON table only support INSERT OVERWRITE for now.") } } + + override def hashCode(): Int = 41 * (41 + path.hashCode) + schema.hashCode() + + override def equals(other: Any): Boolean = other match { + case that: JSONRelation => + (this.path == that.path) && this.schema.sameType(that.schema) + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 9171939f7e8f7..0b770f2251943 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.json -import java.io.StringWriter -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} -import com.fasterxml.jackson.core.JsonProcessingException -import com.fasterxml.jackson.core.JsonFactory +import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException} import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD @@ -50,7 +48,11 @@ private[sql] object JsonRDD extends Logging { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) val allKeys = - parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) + if (schemaData.isEmpty()) { + Set.empty[(String, DataType)] + } else { + parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) + } createSchema(allKeys) } @@ -179,7 +181,12 @@ private[sql] object JsonRDD extends Logging { } private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { - ScalaReflection.typeOfObject orElse { + // For Integer values, use LongType by default. + val useLongType: PartialFunction[Any, DataType] = { + case value: IntegerType.JvmType => LongType + } + + useLongType orElse ScalaReflection.typeOfObject orElse { // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. case value: java.math.BigInteger => DecimalType.Unlimited @@ -196,13 +203,12 @@ private[sql] object JsonRDD extends Logging { * type conflicts. */ private def typeOfArray(l: Seq[Any]): ArrayType = { - val containsNull = l.exists(v => v == null) val elements = l.flatMap(v => Option(v)) if (elements.isEmpty) { // If this JSON array is empty, we use NullType as a placeholder. // If this array is not empty in other JSON objects, we can resolve // the type after we have passed through all JSON objects. - ArrayType(NullType, containsNull) + ArrayType(NullType, containsNull = true) } else { val elementType = elements.map { e => e match { @@ -214,7 +220,7 @@ private[sql] object JsonRDD extends Logging { } }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - ArrayType(elementType, containsNull) + ArrayType(elementType, containsNull = true) } } @@ -242,7 +248,7 @@ private[sql] object JsonRDD extends Logging { // The value associated with the key is an array. // Handle inner structs of an array. def buildKeyPathForInnerStructs(v: Any, t: DataType): Seq[(String, DataType)] = t match { - case ArrayType(e: StructType, containsNull) => { + case ArrayType(e: StructType, _) => { // The elements of this arrays are structs. v.asInstanceOf[Seq[Map[String, Any]]].flatMap(Option(_)).flatMap { element => allKeysWithValueTypes(element) @@ -250,7 +256,7 @@ private[sql] object JsonRDD extends Logging { case (k, t) => (s"$key.$k", t) } } - case ArrayType(t1, containsNull) => + case ArrayType(t1, _) => v.asInstanceOf[Seq[Any]].flatMap(Option(_)).flatMap { element => buildKeyPathForInnerStructs(element, t1) } @@ -303,6 +309,10 @@ private[sql] object JsonRDD extends Logging { val parsed = mapper.readValue(record, classOf[Object]) match { case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] + case _ => + sys.error( + s"Failed to parse record $record. Please make sure that each line of the file " + + "(or each string in the RDD) is a valid JSON object or an array of JSON objects.") } parsed @@ -377,10 +387,12 @@ private[sql] object JsonRDD extends Logging { } } - private def toDate(value: Any): Date = { + private def toDate(value: Any): Int = { value match { // only support string as date - case value: java.lang.String => new Date(DataTypeConversions.stringToTime(value).getTime) + case value: java.lang.String => + DateUtils.millisToDays(DataTypeConversions.stringToTime(value).getTime) + case value: java.sql.Date => DateUtils.fromJavaDate(value) } } @@ -407,6 +419,9 @@ private[sql] object JsonRDD extends Logging { case NullType => null case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) + case MapType(StringType, valueType, _) => + val map = value.asInstanceOf[Map[String, Any]] + map.mapValues(enforceCorrectType(_, valueType)).map(identity) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) case DateType => toDate(value) case TimestampType => toTimestamp(value) @@ -428,14 +443,11 @@ private[sql] object JsonRDD extends Logging { /** Transforms a single Row to JSON using Jackson * - * @param jsonFactory a JsonFactory object to construct a JsonGenerator * @param rowSchema the schema object used for conversion + * @param gen a JsonGenerator object * @param row The row to convert */ - private[sql] def rowToJSON(rowSchema: StructType, jsonFactory: JsonFactory)(row: Row): String = { - val writer = new StringWriter() - val gen = jsonFactory.createGenerator(writer) - + private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = { def valWriter: (DataType, Any) => Unit = { case (_, null) | (NullType, _) => gen.writeNull() case (StringType, v: String) => gen.writeString(v) @@ -477,8 +489,5 @@ private[sql] object JsonRDD extends Logging { } valWriter(rowSchema, row) - gen.close() - writer.toString } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 02e5b015e8ec2..3f97a11ceb97d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -34,10 +34,13 @@ import org.apache.spark.sql.execution.SparkPlan package object sql { /** - * Converts a logical plan into zero or more SparkPlans. + * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting + * with the query planner and is not designed to be stable across spark releases. Developers + * writing libraries should instead consider using the stable APIs provided in + * [[org.apache.spark.sql.sources]] */ @DeveloperApi - protected[sql] type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] /** * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 10df8c3310092..43ca359b51735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.parquet +import java.sql.Timestamp +import java.util.{TimeZone, Calendar} + import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} +import jodd.datetime.JDateTime import parquet.column.Dictionary import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} import parquet.schema.MessageType @@ -26,6 +30,7 @@ import parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.CatalystConverter.FieldType import org.apache.spark.sql.types._ +import org.apache.spark.sql.parquet.timestamp.NanoTime /** * Collection of converters of Parquet types (group and primitive types) that @@ -61,6 +66,11 @@ private[sql] object CatalystConverter { // Using a different value will result in Parquet silently dropping columns. val ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME = "bag" val ARRAY_ELEMENTS_SCHEMA_NAME = "array" + // SPARK-4520: Thrift generated parquet files have different array element + // schema names than avro. Thrift parquet uses array_schema_name + "_tuple" + // as opposed to "array" used by default. For more information, check + // TestThriftSchemaConverter.java in parquet.thrift. + val THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX = "_tuple" val MAP_KEY_SCHEMA_NAME = "key" val MAP_VALUE_SCHEMA_NAME = "value" val MAP_SCHEMA_NAME = "map" @@ -117,12 +127,24 @@ private[sql] object CatalystConverter { parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) } } + case DateType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addInt(value: Int): Unit = + parent.updateDate(fieldIndex, value.asInstanceOf[DateType.JvmType]) + } + } case d: DecimalType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addBinary(value: Binary): Unit = parent.updateDecimal(fieldIndex, value, d) } } + case TimestampType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.updateTimestamp(fieldIndex, value) + } + } // All other primitive types use the default converter case ctype: PrimitiveType => { // note: need the type tag here! new CatalystPrimitiveConverter(parent, fieldIndex) @@ -176,6 +198,9 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = updateField(fieldIndex, value) + protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = + updateField(fieldIndex, value) + protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = updateField(fieldIndex, value) @@ -197,9 +222,11 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = updateField(fieldIndex, value) - protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { + protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = + updateField(fieldIndex, readTimestamp(value)) + + protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) - } protected[parquet] def isRootConverter: Boolean = parent == null @@ -232,6 +259,13 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { unscaled = (unscaled << (64 - numBits)) >> (64 - numBits) dest.set(unscaled, precision, scale) } + + /** + * Read a Timestamp value from a Parquet Int96Value + */ + protected[parquet] def readTimestamp(value: Binary): Timestamp = { + CatalystTimestampConverter.convertToTimestamp(value) + } } /** @@ -363,6 +397,9 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = current.setInt(fieldIndex, value) + override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = + current.update(fieldIndex, value) + override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = current.setLong(fieldIndex, value) @@ -384,6 +421,9 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = current.setString(fieldIndex, value) + override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = + current.update(fieldIndex, readTimestamp(value)) + override protected[parquet] def updateDecimal( fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { var decimal = current(fieldIndex).asInstanceOf[Decimal] @@ -454,6 +494,73 @@ private[parquet] object CatalystArrayConverter { val INITIAL_ARRAY_SIZE = 20 } +private[parquet] object CatalystTimestampConverter { + // TODO most part of this comes from Hive-0.14 + // Hive code might have some issues, so we need to keep an eye on it. + // Also we use NanoTime and Int96Values from parquet-examples. + // We utilize jodd to convert between NanoTime and Timestamp + val parquetTsCalendar = new ThreadLocal[Calendar] + def getCalendar: Calendar = { + // this is a cache for the calendar instance. + if (parquetTsCalendar.get == null) { + parquetTsCalendar.set(Calendar.getInstance(TimeZone.getTimeZone("GMT"))) + } + parquetTsCalendar.get + } + val NANOS_PER_SECOND: Long = 1000000000 + val SECONDS_PER_MINUTE: Long = 60 + val MINUTES_PER_HOUR: Long = 60 + val NANOS_PER_MILLI: Long = 1000000 + + def convertToTimestamp(value: Binary): Timestamp = { + val nt = NanoTime.fromBinary(value) + val timeOfDayNanos = nt.getTimeOfDayNanos + val julianDay = nt.getJulianDay + val jDateTime = new JDateTime(julianDay.toDouble) + val calendar = getCalendar + calendar.set(Calendar.YEAR, jDateTime.getYear) + calendar.set(Calendar.MONTH, jDateTime.getMonth - 1) + calendar.set(Calendar.DAY_OF_MONTH, jDateTime.getDay) + + // written in command style + var remainder = timeOfDayNanos + calendar.set( + Calendar.HOUR_OF_DAY, + (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR)).toInt) + remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR) + calendar.set( + Calendar.MINUTE, (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE)).toInt) + remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE) + calendar.set(Calendar.SECOND, (remainder / NANOS_PER_SECOND).toInt) + val nanos = remainder % NANOS_PER_SECOND + val ts = new Timestamp(calendar.getTimeInMillis) + ts.setNanos(nanos.toInt) + ts + } + + def convertFromTimestamp(ts: Timestamp): Binary = { + val calendar = getCalendar + calendar.setTime(ts) + val jDateTime = new JDateTime(calendar.get(Calendar.YEAR), + calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH)) + // Hive-0.14 didn't set hour before get day number, while the day number should + // has something to do with hour, since julian day number grows at 12h GMT + // here we just follow what hive does. + val julianDay = jDateTime.getJulianDayNumber + + val hour = calendar.get(Calendar.HOUR_OF_DAY) + val minute = calendar.get(Calendar.MINUTE) + val second = calendar.get(Calendar.SECOND) + val nanos = ts.getNanos + // Hive-0.14 would use hours directly, that might be wrong, since the day starts + // from 12h in Julian. here we just follow what hive does. + val nanosOfDay = nanos + second * NANOS_PER_SECOND + + minute * NANOS_PER_SECOND * SECONDS_PER_MINUTE + + hour * NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR + NanoTime(julianDay, nanosOfDay).toBinary + } +} + /** * A `parquet.io.api.GroupConverter` that converts a single-element groups that * match the characteristics of an array (see diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index a54485e719dad..fcb9513ab66f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -18,11 +18,13 @@ package org.apache.spark.sql.parquet import java.io.IOException +import java.util.logging.Level import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction -import parquet.hadoop.ParquetOutputFormat +import org.apache.spark.sql.types.{StructType, DataType} +import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} import parquet.hadoop.metadata.CompressionCodecName import parquet.schema.MessageType @@ -65,20 +67,26 @@ private[sql] case class ParquetRelation( ParquetTypesConverter.readSchemaFromFile( new Path(path.split(",").head), conf, - sqlContext.conf.isParquetBinaryAsString) - + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.isParquetINT96AsTimestamp) lazy val attributeMap = AttributeMap(output.map(o => o -> o)) - override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] + override def newInstance(): this.type = { + ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] + } // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case p: ParquetRelation => p.path == path && p.output == output case _ => false } + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(path, output) + } + // TODO: Use data from the footers. override lazy val statistics = Statistics(sizeInBytes = sqlContext.conf.defaultSizeInBytes) } @@ -91,7 +99,7 @@ private[sql] object ParquetRelation { // checks first to see if there's any handlers already set // and if not it creates them. If this method executes prior // to that class being loaded then: - // 1) there's no handlers installed so there's none to + // 1) there's no handlers installed so there's none to // remove. But when it IS finally loaded the desired affect // of removing them is circumvented. // 2) The parquet.Log static initializer calls setUseParentHanders(false) @@ -99,7 +107,7 @@ private[sql] object ParquetRelation { // // Therefore we need to force the class to be loaded. // This should really be resolved by Parquet. - Class.forName(classOf[parquet.Log].getName()) + Class.forName(classOf[parquet.Log].getName) // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. @@ -108,6 +116,11 @@ private[sql] object ParquetRelation { // TODO(witgo): Need to set the log level ? // if(parquetLogger.getLevel != null) parquetLogger.setLevel(null) if (!parquetLogger.getUseParentHandlers) parquetLogger.setUseParentHandlers(true) + + // Disables WARN log message in ParquetOutputCommitter. + // See https://issues.apache.org/jira/browse/SPARK-5968 for details + Class.forName(classOf[ParquetOutputCommitter].getName) + java.util.logging.Logger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) } // The element type for the RDDs that this relation maps to. @@ -166,9 +179,13 @@ private[sql] object ParquetRelation { sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED) .name()) ParquetRelation.enableLogForwarding() - ParquetTypesConverter.writeMetaData(attributes, path, conf) + // This is a hack. We always set nullable/containsNull/valueContainsNull to true + // for the schema of a parquet data. + val schema = StructType.fromAttributes(attributes).asNullable + val newAttributes = schema.toAttributes + ParquetTypesConverter.writeMetaData(newAttributes, path, conf) new ParquetRelation(path.toString, Some(conf), sqlContext) { - override val output = attributes + override val output = newAttributes } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 28cd17fde46ab..1c868da23e060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.parquet import java.io.IOException import java.lang.{Long => JLong} -import java.text.SimpleDateFormat -import java.text.NumberFormat +import java.text.{NumberFormat, SimpleDateFormat} import java.util.concurrent.{Callable, TimeUnit} -import java.util.{ArrayList, Collections, Date, List => JList} +import java.util.{Date, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable @@ -43,11 +42,13 @@ import parquet.io.ParquetDecodingException import parquet.schema.MessageType import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} +import org.apache.spark.sql.types.StructType import org.apache.spark.{Logging, SerializableWritable, TaskContext} /** @@ -55,7 +56,7 @@ import org.apache.spark.{Logging, SerializableWritable, TaskContext} * Parquet table scan operator. Imports the file that backs the given * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``. */ -case class ParquetTableScan( +private[sql] case class ParquetTableScan( attributes: Seq[Attribute], relation: ParquetRelation, columnPruningPred: Seq[Expression]) @@ -125,6 +126,13 @@ case class ParquetTableScan( conf) if (requestedPartitionOrdinals.nonEmpty) { + // This check is based on CatalystConverter.createRootConverter. + val primitiveRow = output.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType)) + + // Uses temporary variable to avoid the whole `ParquetTableScan` object being captured into + // the `mapPartitionsWithInputSplit` closure below. + val outputSize = output.size + baseRDD.mapPartitionsWithInputSplit { case (split, iter) => val partValue = "([^=]+)=([^=]+)".r val partValues = @@ -142,19 +150,47 @@ case class ParquetTableScan( relation.partitioningAttributes .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) - new Iterator[Row] { - def hasNext = iter.hasNext - def next() = { - val row = iter.next()._2.asInstanceOf[SpecificMutableRow] - - // Parquet will leave partitioning columns empty, so we fill them in here. - var i = 0 - while (i < requestedPartitionOrdinals.size) { - row(requestedPartitionOrdinals(i)._2) = - partitionRowValues(requestedPartitionOrdinals(i)._1) - i += 1 + if (primitiveRow) { + new Iterator[Row] { + def hasNext: Boolean = iter.hasNext + def next(): Row = { + // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. + val row = iter.next()._2.asInstanceOf[SpecificMutableRow] + + // Parquet will leave partitioning columns empty, so we fill them in here. + var i = 0 + while (i < requestedPartitionOrdinals.size) { + row(requestedPartitionOrdinals(i)._2) = + partitionRowValues(requestedPartitionOrdinals(i)._1) + i += 1 + } + row + } + } + } else { + // Create a mutable row since we need to fill in values from partition columns. + val mutableRow = new GenericMutableRow(outputSize) + new Iterator[Row] { + def hasNext: Boolean = iter.hasNext + def next(): Row = { + // We are using CatalystGroupConverter and it returns a GenericRow. + // Since GenericRow is not mutable, we just cast it to a Row. + val row = iter.next()._2.asInstanceOf[Row] + + var i = 0 + while (i < row.size) { + mutableRow(i) = row(i) + i += 1 + } + // Parquet will leave partitioning columns empty, so we fill them in here. + i = 0 + while (i < requestedPartitionOrdinals.size) { + mutableRow(requestedPartitionOrdinals(i)._2) = + partitionRowValues(requestedPartitionOrdinals(i)._1) + i += 1 + } + mutableRow } - row } } } @@ -210,7 +246,7 @@ case class ParquetTableScan( * (only detected via filename pattern so will not catch all cases). */ @DeveloperApi -case class InsertIntoParquetTable( +private[sql] case class InsertIntoParquetTable( relation: ParquetRelation, child: SparkPlan, overwrite: Boolean = false) @@ -219,7 +255,7 @@ case class InsertIntoParquetTable( /** * Inserts all rows into the Parquet file. */ - override def execute() = { + override def execute(): RDD[Row] = { // TODO: currently we do not check whether the "schema"s are compatible // That means if one first creates a table and then INSERTs data with // and incompatible schema the execution will fail. It would be nice @@ -242,7 +278,10 @@ case class InsertIntoParquetTable( ParquetOutputFormat.setWriteSupportClass(job, writeSupport) val conf = ContextUtil.getConfiguration(job) - RowWriteSupport.setSchema(relation.output, conf) + // This is a hack. We always set nullable/containsNull/valueContainsNull to true + // for the schema of a parquet data. + val schema = StructType.fromAttributes(relation.output).asNullable + RowWriteSupport.setSchema(schema.toAttributes, conf) val fspath = new Path(relation.path) val fs = fspath.getFileSystem(conf) @@ -263,7 +302,7 @@ case class InsertIntoParquetTable( childRdd } - override def output = child.output + override def output: Seq[Attribute] = child.output /** * Stores the given Row RDD as a Hadoop file. @@ -317,7 +356,7 @@ case class InsertIntoParquetTable( } finally { writer.close(hadoopContext) } - committer.commitTask(hadoopContext) + SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) 1 } val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) @@ -373,8 +412,6 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) private[parquet] class FilteringParquetRowInputFormat extends parquet.hadoop.ParquetInputFormat[Row] with Logging { - private var footers: JList[Footer] = _ - private var fileStatuses = Map.empty[Path, FileStatus] override def createRecordReader( @@ -395,46 +432,15 @@ private[parquet] class FilteringParquetRowInputFormat } } - override def getFooters(jobContext: JobContext): JList[Footer] = { - import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.footerCache - - if (footers eq null) { - val conf = ContextUtil.getConfiguration(jobContext) - val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) - val statuses = listStatus(jobContext) - fileStatuses = statuses.map(file => file.getPath -> file).toMap - if (statuses.isEmpty) { - footers = Collections.emptyList[Footer] - } else if (!cacheMetadata) { - // Read the footers from HDFS - footers = getFooters(conf, statuses) - } else { - // Read only the footers that are not in the footerCache - val foundFooters = footerCache.getAllPresent(statuses) - val toFetch = new ArrayList[FileStatus] - for (s <- statuses) { - if (!foundFooters.containsKey(s)) { - toFetch.add(s) - } - } - val newFooters = new mutable.HashMap[FileStatus, Footer] - if (toFetch.size > 0) { - val startFetch = System.currentTimeMillis - val fetched = getFooters(conf, toFetch) - logInfo(s"Fetched $toFetch footers in ${System.currentTimeMillis - startFetch} ms") - for ((status, i) <- toFetch.zipWithIndex) { - newFooters(status) = fetched.get(i) - } - footerCache.putAll(newFooters) - } - footers = new ArrayList[Footer](statuses.size) - for (status <- statuses) { - footers.add(newFooters.getOrElse(status, foundFooters.get(status))) - } - } - } + // This is only a temporary solution sicne we need to use fileStatuses in + // both getClientSideSplits and getTaskSideSplits. It can be removed once we get rid of these + // two methods. + override def getSplits(jobContext: JobContext): JList[InputSplit] = { + // First set fileStatuses. + val statuses = listStatus(jobContext) + fileStatuses = statuses.map(file => file.getPath -> file).toMap - footers + super.getSplits(jobContext) } // TODO Remove this method and related code once PARQUET-16 is fixed @@ -459,13 +465,21 @@ private[parquet] class FilteringParquetRowInputFormat val getGlobalMetaData = classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) getGlobalMetaData.setAccessible(true) - val globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] + var globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] if (globalMetaData == null) { val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] return splits } + val metadata = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + val mergedMetadata = globalMetaData + .getKeyValueMetaData + .updated(RowReadSupport.SPARK_METADATA_KEY, setAsJavaSet(Set(metadata))) + + globalMetaData = new GlobalMetaData(globalMetaData.getSchema, + mergedMetadata, globalMetaData.getCreatedBy) + val readContext = getReadSupport(configuration).init( new InitContext(configuration, globalMetaData.getKeyValueMetaData, @@ -498,6 +512,7 @@ private[parquet] class FilteringParquetRowInputFormat import parquet.filter2.compat.FilterCompat.Filter import parquet.filter2.compat.RowGroupFilter + import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) @@ -647,6 +662,6 @@ private[parquet] object FileSystemHelper { sys.error("ERROR: attempting to append to set of Parquet files and found file" + s"that does not match name pattern: $other") case _ => 0 - }.reduceLeft((a, b) => if (a < b) b else a) + }.reduceOption(_ max _).getOrElse(0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index fd63ad8144064..5a1b15490d273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -83,7 +83,8 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { // TODO: Why it can be null? if (schema == null) { log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false) + schema = ParquetTypesConverter.convertToAttributes( + parquetSchema, false, true) } log.debug(s"list of attributes that will be read: $schema") new RowRecordMaterializer(parquetSchema, schema) @@ -98,7 +99,11 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { val requestedAttributes = RowReadSupport.getRequestedSchema(configuration) if (requestedAttributes != null) { - parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) + // If the parquet file is thrift derived, there is a good chance that + // it will have the thrift class in metadata. + val isThriftDerived = keyValueMetaData.keySet().contains("thrift.class") + parquetSchema = ParquetTypesConverter + .convertFromAttributes(requestedAttributes, isThriftDerived) metadata.put( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, ParquetTypesConverter.convertToString(requestedAttributes)) @@ -154,7 +159,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") + s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") } var index = 0 @@ -184,12 +189,12 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case t @ StructType(_) => writeStruct( t, value.asInstanceOf[CatalystConverter.StructScalaType[_]]) - case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value) + case _ => writePrimitive(schema.asInstanceOf[NativeType], value) } } } - private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = { + private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = { if (value != null) { schema match { case StringType => writer.addBinary( @@ -202,10 +207,12 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case IntegerType => writer.addInteger(value.asInstanceOf[Int]) case ShortType => writer.addInteger(value.asInstanceOf[Short]) case LongType => writer.addLong(value.asInstanceOf[Long]) + case TimestampType => writeTimestamp(value.asInstanceOf[java.sql.Timestamp]) case ByteType => writer.addInteger(value.asInstanceOf[Byte]) case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case DateType => writer.addInteger(value.asInstanceOf[Int]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") @@ -307,6 +314,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) } + private[parquet] def writeTimestamp(ts: java.sql.Timestamp): Unit = { + val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp(ts) + writer.addBinary(binaryNanoTime) + } } // Optimized for non-nested rows @@ -315,7 +326,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") + s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") } var index = 0 @@ -338,10 +349,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { index: Int): Unit = { ctype match { case StringType => writer.addBinary( - Binary.fromByteArray( - record(index).asInstanceOf[String].getBytes("utf-8") - ) - ) + Binary.fromByteArray(record(index).asInstanceOf[String].getBytes("utf-8"))) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(record.getInt(index)) @@ -351,6 +359,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case DoubleType => writer.addDouble(record.getDouble(index)) case FloatType => writer.addFloat(record.getFloat(index)) case BooleanType => writer.addBoolean(record.getBoolean(index)) + case DateType => writer.addInteger(record.getInt(index)) + case TimestampType => writeTimestamp(record(index).asInstanceOf[java.sql.Timestamp]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 9d6c529574da0..d6ea6679c5966 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,8 +23,8 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} import org.apache.spark.util.Utils /** @@ -34,10 +34,11 @@ import org.apache.spark.util.Utils * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -trait ParquetTest { +private[sql] trait ParquetTest { val sqlContext: SQLContext - import sqlContext._ + import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder} + import sqlContext.{conf, sparkContext} protected def configuration = sparkContext.hadoopConfiguration @@ -49,11 +50,11 @@ trait ParquetTest { */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(getConf(key)).toOption) - (keys, values).zipped.foreach(setConf) + val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) + (keys, values).zipped.foreach(conf.setConf) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => setConf(key, value) + case (key, Some(value)) => conf.setConf(key, value) case (key, None) => conf.unsetConf(key) } } @@ -89,7 +90,7 @@ trait ParquetTest { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).saveAsParquetFile(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().saveAsParquetFile(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -98,17 +99,17 @@ trait ParquetTest { * Writes `data` to a Parquet file and reads it back as a [[DataFrame]], * which is then passed to `f`. The Parquet file will be deleted after `f` returns. */ - protected def withParquetRDD[T <: Product: ClassTag: TypeTag] + protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(parquetFile(path))) + withParquetFile(data)(path => f(sqlContext.parquetFile(path))) } /** * Drops temporary table `tableName` after calling `f`. */ protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally dropTempTable(tableName) + try f finally sqlContext.dropTempTable(tableName) } /** @@ -119,9 +120,36 @@ trait ParquetTest { protected def withParquetTable[T <: Product: ClassTag: TypeTag] (data: Seq[T], tableName: String) (f: => Unit): Unit = { - withParquetRDD(data) { rdd => - sqlContext.registerRDDAsTable(rdd, tableName) + withParquetDataFrame(data) { df => + sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } + + protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + } + + protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + } + + protected def makePartitionDir( + basePath: File, + defaultPartitionName: String, + partitionCols: (String, Any)*): File = { + val partNames = partitionCols.map { case (k, v) => + val valueString = if (v == null || v == "") defaultPartitionName else v.toString + s"$k=$valueString" + } + + val partDir = partNames.foldLeft(basePath) { (parent, child) => + new File(parent, child) + } + + assert(partDir.mkdirs(), s"Couldn't create directory $partDir") + partDir + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index d5993656e0225..e4a10aa2ae6c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.spark.sql.test.TestSQLContext import parquet.example.data.{GroupWriter, Group} -import parquet.example.data.simple.SimpleGroup +import parquet.example.data.simple.{NanoTime, SimpleGroup} import parquet.hadoop.{ParquetReader, ParquetFileReader, ParquetWriter} import parquet.hadoop.api.WriteSupport import parquet.hadoop.api.WriteSupport.WriteContext @@ -63,6 +63,7 @@ private[sql] object ParquetTestData { optional int64 mylong; optional float myfloat; optional double mydouble; + optional int96 mytimestamp; }""" // field names for test assertion error messages @@ -72,7 +73,8 @@ private[sql] object ParquetTestData { "mystring:String", "mylong:Long", "myfloat:Float", - "mydouble:Double" + "mydouble:Double", + "mytimestamp:Timestamp" ) val subTestSchema = @@ -98,6 +100,7 @@ private[sql] object ParquetTestData { optional int64 myoptlong; optional float myoptfloat; optional double myoptdouble; + optional int96 mytimestamp; } """ @@ -236,6 +239,7 @@ private[sql] object ParquetTestData { record.add(3, i.toLong << 33) record.add(4, 2.5F) record.add(5, 4.5D) + record.add(6, new NanoTime(1,2)) writer.write(record) } writer.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 6d8c682ccced8..da668f068613b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,24 +19,23 @@ package org.apache.spark.sql.parquet import java.io.IOException +import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job - import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} -import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} +import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import parquet.hadoop.util.ContextUtil -import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType} -import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata} +import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} import parquet.schema.Type.Repetition +import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ +import org.apache.spark.{Logging, SparkException} // Implicits import scala.collection.JavaConversions._ @@ -54,7 +53,8 @@ private[parquet] object ParquetTypesConverter extends Logging { def toPrimitiveDataType( parquetType: ParquetPrimitiveType, - binaryAsString: Boolean): DataType = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean): DataType = { val originalType = parquetType.getOriginalType val decimalInfo = parquetType.getDecimalMetadata parquetType.getPrimitiveTypeName match { @@ -64,8 +64,11 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType case ParquetPrimitiveTypeName.FLOAT => FloatType + case ParquetPrimitiveTypeName.INT32 + if originalType == ParquetOriginalType.DATE => DateType case ParquetPrimitiveTypeName.INT32 => IntegerType case ParquetPrimitiveTypeName.INT64 => LongType + case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType case ParquetPrimitiveTypeName.INT96 => // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? sys.error("Potential loss of precision: cannot convert INT96") @@ -103,7 +106,9 @@ private[parquet] object ParquetTypesConverter extends Logging { * @param parquetType The type to convert. * @return The corresponding Catalyst type. */ - def toDataType(parquetType: ParquetType, isBinaryAsString: Boolean): DataType = { + def toDataType(parquetType: ParquetType, + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): DataType = { def correspondsToMap(groupType: ParquetGroupType): Boolean = { if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { false @@ -125,7 +130,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString) + toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString, isInt96AsTimestamp) } else { val groupType = parquetType.asGroupType() parquetType.getOriginalType match { @@ -137,9 +142,12 @@ private[parquet] object ParquetTypesConverter extends Logging { if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { val bag = field.asGroupType() assert(bag.getFieldCount == 1) - ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true) + ArrayType( + toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), + containsNull = true) } else { - ArrayType(toDataType(field, isBinaryAsString), containsNull = false) + ArrayType( + toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) } } case ParquetOriginalType.MAP => { @@ -152,8 +160,10 @@ private[parquet] object ParquetTypesConverter extends Logging { "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) - val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) + val keyType = + toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) + val valueType = + toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) MapType(keyType, valueType, keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) } @@ -163,8 +173,10 @@ private[parquet] object ParquetTypesConverter extends Logging { val keyValueGroup = groupType.getFields.apply(0).asGroupType() assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) - val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) + val keyType = + toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) + val valueType = + toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) MapType(keyType, valueType, keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) } else if (correspondsToArray(groupType)) { // ArrayType @@ -172,16 +184,19 @@ private[parquet] object ParquetTypesConverter extends Logging { if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { val bag = field.asGroupType() assert(bag.getFieldCount == 1) - ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true) + ArrayType( + toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), + containsNull = true) } else { - ArrayType(toDataType(field, isBinaryAsString), containsNull = false) + ArrayType( + toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) } } else { // everything else: StructType val fields = groupType .getFields .map(ptype => new StructField( ptype.getName, - toDataType(ptype, isBinaryAsString), + toDataType(ptype, isBinaryAsString, isInt96AsTimestamp), ptype.getRepetition != Repetition.REQUIRED)) StructType(fields) } @@ -209,7 +224,10 @@ private[parquet] object ParquetTypesConverter extends Logging { // There is no type for Byte or Short so we promote them to INT32. case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case DateType => Some(ParquetTypeInfo( + ParquetPrimitiveTypeName.INT32, Some(ParquetOriginalType.DATE))) case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) + case TimestampType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT96)) case DecimalType.Fixed(precision, scale) if precision <= 18 => // TODO: for now, our writer only supports decimals that fit in a Long Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, @@ -270,13 +288,19 @@ private[parquet] object ParquetTypesConverter extends Logging { ctype: DataType, name: String, nullable: Boolean = true, - inArray: Boolean = false): ParquetType = { + inArray: Boolean = false, + toThriftSchemaNames: Boolean = false): ParquetType = { val repetition = if (inArray) { Repetition.REPEATED } else { if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED } + val arraySchemaName = if (toThriftSchemaNames) { + name + CatalystConverter.THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX + } else { + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME + } val typeInfo = fromPrimitiveDataType(ctype) typeInfo.map { case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => @@ -291,22 +315,24 @@ private[parquet] object ParquetTypesConverter extends Logging { }.getOrElse { ctype match { case udt: UserDefinedType[_] => { - fromDataType(udt.sqlType, name, nullable, inArray) + fromDataType(udt.sqlType, name, nullable, inArray, toThriftSchemaNames) } case ArrayType(elementType, false) => { val parquetElementType = fromDataType( elementType, - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + arraySchemaName, nullable = false, - inArray = true) + inArray = true, + toThriftSchemaNames) ConversionPatterns.listType(repetition, name, parquetElementType) } case ArrayType(elementType, true) => { val parquetElementType = fromDataType( elementType, - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + arraySchemaName, nullable = true, - inArray = false) + inArray = false, + toThriftSchemaNames) ConversionPatterns.listType( repetition, name, @@ -317,7 +343,8 @@ private[parquet] object ParquetTypesConverter extends Logging { } case StructType(structFields) => { val fields = structFields.map { - field => fromDataType(field.dataType, field.name, field.nullable, inArray = false) + field => fromDataType(field.dataType, field.name, field.nullable, + inArray = false, toThriftSchemaNames) } new ParquetGroupType(repetition, name, fields.toSeq) } @@ -327,13 +354,15 @@ private[parquet] object ParquetTypesConverter extends Logging { keyType, CatalystConverter.MAP_KEY_SCHEMA_NAME, nullable = false, - inArray = false) + inArray = false, + toThriftSchemaNames) val parquetValueType = fromDataType( valueType, CatalystConverter.MAP_VALUE_SCHEMA_NAME, nullable = valueContainsNull, - inArray = false) + inArray = false, + toThriftSchemaNames) ConversionPatterns.mapType( repetition, name, @@ -345,7 +374,9 @@ private[parquet] object ParquetTypesConverter extends Logging { } } - def convertToAttributes(parquetSchema: ParquetType, isBinaryAsString: Boolean): Seq[Attribute] = { + def convertToAttributes(parquetSchema: ParquetType, + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): Seq[Attribute] = { parquetSchema .asGroupType() .getFields @@ -353,14 +384,16 @@ private[parquet] object ParquetTypesConverter extends Logging { field => new AttributeReference( field.getName, - toDataType(field, isBinaryAsString), + toDataType(field, isBinaryAsString, isInt96AsTimestamp), field.getRepetition != Repetition.REQUIRED)()) } - def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + def convertFromAttributes(attributes: Seq[Attribute], + toThriftSchemaNames: Boolean = false): MessageType = { val fields = attributes.map( attribute => - fromDataType(attribute.dataType, attribute.name, attribute.nullable)) + fromDataType(attribute.dataType, attribute.name, attribute.nullable, + toThriftSchemaNames = toThriftSchemaNames)) new MessageType("root", fields) } @@ -476,7 +509,8 @@ private[parquet] object ParquetTypesConverter extends Logging { def readSchemaFromFile( origPath: Path, conf: Option[Configuration], - isBinaryAsString: Boolean): Seq[Attribute] = { + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): Seq[Attribute] = { val keyValueMetadata: java.util.Map[String, String] = readMetaData(origPath, conf) .getFileMetaData @@ -485,7 +519,9 @@ private[parquet] object ParquetTypesConverter extends Logging { convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) } else { val attributes = convertToAttributes( - readMetaData(origPath, conf).getFileMetaData.getSchema, isBinaryAsString) + readMetaData(origPath, conf).getFileMetaData.getSchema, + isBinaryAsString, + isInt96AsTimestamp) log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") attributes } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 1e794cad73936..d035897f195e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -16,209 +16,466 @@ */ package org.apache.spark.sql.parquet -import java.util.{List => JList} +import java.io.IOException +import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} +import java.net.URI +import java.text.SimpleDateFormat +import java.util.{Date, List => JList} import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.util.Try -import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} + import parquet.filter2.predicate.FilterApi -import parquet.hadoop.ParquetInputFormat +import parquet.format.converter.ParquetMetadataConverter +import parquet.hadoop.metadata.CompressionCodecName import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetInputFormat, _} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{NewHadoopPartition, RDD} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -import org.apache.spark.sql.{Row, SQLConf, SQLContext} -import org.apache.spark.{Logging, Partition => SparkPartition} - +import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} +import org.apache.spark.{Logging, Partition => SparkPartition, SerializableWritable, SparkException, TaskContext} /** - * Allows creation of parquet based tables using the syntax - * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option - * required is `path`, which should be the location of a collection of, optionally partitioned, - * parquet files. + * Allows creation of Parquet based tables using the syntax: + * {{{ + * CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet OPTIONS (...) + * }}} + * + * Supported options include: + * + * - `path`: Required. When reading Parquet files, `path` should point to the location of the + * Parquet file(s). It can be either a single raw Parquet file, or a directory of Parquet files. + * In the latter case, this data source tries to discover partitioning information if the the + * directory is structured in the same style of Hive partitioned tables. When writing Parquet + * file, `path` should point to the destination folder. + * + * - `mergeSchema`: Optional. Indicates whether we should merge potentially different (but + * compatible) schemas stored in all Parquet part-files. + * + * - `partition.defaultName`: Optional. Partition name used when a value of a partition column is + * null or empty string. This is similar to the `hive.exec.default.partition.name` configuration + * in Hive. */ -class DefaultSource extends RelationProvider { +private[sql] class DefaultSource + extends RelationProvider + with SchemaRelationProvider + with CreatableRelationProvider { + + private def checkPath(parameters: Map[String, String]): String = { + parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) + } + /** Returns a new base relation with the given parameters. */ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - val path = - parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) + ParquetRelation2(Seq(checkPath(parameters)), parameters, None)(sqlContext) + } - ParquetRelation2(path)(sqlContext) + /** Returns a new base relation with the given parameters and schema. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + ParquetRelation2(Seq(checkPath(parameters)), parameters, Some(schema))(sqlContext) + } + + /** Returns a new base relation with the given parameters and save given data into it. */ + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + val path = checkPath(parameters) + val filesystemPath = new Path(path) + val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val doInsertion = (mode, fs.exists(filesystemPath)) match { + case (SaveMode.ErrorIfExists, true) => + sys.error(s"path $path already exists.") + case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => + true + case (SaveMode.Ignore, exists) => + !exists + } + + val relation = if (doInsertion) { + // This is a hack. We always set nullable/containsNull/valueContainsNull to true + // for the schema of a parquet data. + val df = + sqlContext.createDataFrame( + data.queryExecution.toRdd, + data.schema.asNullable, + needsConversion = false) + val createdRelation = + createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2] + createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite) + createdRelation + } else { + // If the save mode is Ignore, we will just create the relation based on existing data. + createRelation(sqlContext, parameters) + } + + relation } } -private[parquet] case class Partition(partitionValues: Map[String, Any], files: Seq[FileStatus]) +private[sql] case class Partition(values: Row, path: String) + +private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) /** * An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is - * currently not intended as a full replacement of the parquet support in Spark SQL though it is - * likely that it will eventually subsume the existing physical plan implementation. - * - * Compared with the current implementation, this class has the following notable differences: + * intended as a full replacement of the Parquet support in Spark SQL. The old implementation will + * be deprecated and eventually removed once this version is proved to be stable enough. * - * Partitioning: Partitions are auto discovered and must be in the form of directories `key=value/` - * located at `path`. Currently only a single partitioning column is supported and it must - * be an integer. This class supports both fully self-describing data, which contains the partition - * key, and data where the partition key is only present in the folder structure. The presence - * of the partitioning key in the data is also auto-detected. The `null` partition is not yet - * supported. + * Compared with the old implementation, this class has the following notable differences: * - * Metadata: The metadata is automatically discovered by reading the first parquet file present. - * There is currently no support for working with files that have different schema. Additionally, - * when parquet metadata caching is turned on, the FileStatus objects for all data will be cached - * to improve the speed of interactive querying. When data is added to a table it must be dropped - * and recreated to pick up any changes. - * - * Statistics: Statistics for the size of the table are automatically populated during metadata - * discovery. + * - Partitioning discovery: Hive style multi-level partitions are auto discovered. + * - Metadata discovery: Parquet is a format comes with schema evolving support. This data source + * can detect and merge schemas from all Parquet part-files as long as they are compatible. + * Also, metadata and [[FileStatus]]es are cached for better performance. + * - Statistics: Statistics for the size of the table are automatically populated during schema + * discovery. */ @DeveloperApi -case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) - extends CatalystScan with Logging { +private[sql] case class ParquetRelation2( + paths: Seq[String], + parameters: Map[String, String], + maybeSchema: Option[StructType] = None, + maybePartitionSpec: Option[PartitionSpec] = None)( + @transient val sqlContext: SQLContext) + extends BaseRelation + with CatalystScan + with InsertableRelation + with SparkHadoopMapReduceUtil + with Logging { + + // Should we merge schemas from all Parquet part-files? + private val shouldMergeSchemas = + parameters.getOrElse(ParquetRelation2.MERGE_SCHEMA, "true").toBoolean + + // Optional Metastore schema, used when converting Hive Metastore Parquet table + private val maybeMetastoreSchema = + parameters + .get(ParquetRelation2.METASTORE_SCHEMA) + .map(s => DataType.fromJson(s).asInstanceOf[StructType]) + + // Hive uses this as part of the default partition name when the partition column value is null + // or empty string + private val defaultPartitionName = parameters.getOrElse( + ParquetRelation2.DEFAULT_PARTITION_NAME, "__HIVE_DEFAULT_PARTITION__") + + override def equals(other: Any): Boolean = other match { + case relation: ParquetRelation2 => + // If schema merging is required, we don't compare the actual schemas since they may evolve. + val schemaEquality = if (shouldMergeSchemas) { + shouldMergeSchemas == relation.shouldMergeSchemas + } else { + schema == relation.schema + } + + paths.toSet == relation.paths.toSet && + schemaEquality && + maybeMetastoreSchema == relation.maybeMetastoreSchema && + maybePartitionSpec == relation.maybePartitionSpec + + case _ => false + } + + override def hashCode(): Int = { + if (shouldMergeSchemas) { + com.google.common.base.Objects.hashCode( + shouldMergeSchemas: java.lang.Boolean, + paths.toSet, + maybeMetastoreSchema, + maybePartitionSpec) + } else { + com.google.common.base.Objects.hashCode( + shouldMergeSchemas: java.lang.Boolean, + schema, + paths.toSet, + maybeMetastoreSchema, + maybePartitionSpec) + } + } + + private[sql] def sparkContext = sqlContext.sparkContext - def sparkContext = sqlContext.sparkContext + private class MetadataCache { + // `FileStatus` objects of all "_metadata" files. + private var metadataStatuses: Array[FileStatus] = _ - // Minor Hack: scala doesnt seem to respect @transient for vals declared via extraction - @transient - private var partitionKeys: Seq[String] = _ - @transient - private var partitions: Seq[Partition] = _ - discoverPartitions() + // `FileStatus` objects of all "_common_metadata" files. + private var commonMetadataStatuses: Array[FileStatus] = _ - // TODO: Only finds the first partition, assumes the key is of type Integer... - private def discoverPartitions() = { - val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) - val partValue = "([^=]+)=([^=]+)".r + // Parquet footer cache. + var footers: Map[FileStatus, Footer] = _ - val childrenOfPath = fs.listStatus(new Path(path)).filterNot(_.getPath.getName.startsWith("_")) - val childDirs = childrenOfPath.filter(s => s.isDir) + // `FileStatus` objects of all data files (Parquet part-files). + var dataStatuses: Array[FileStatus] = _ - if (childDirs.size > 0) { - val partitionPairs = childDirs.map(_.getPath.getName).map { - case partValue(key, value) => (key, value) + // Partition spec of this table, including names, data types, and values of each partition + // column, and paths of each partition. + var partitionSpec: PartitionSpec = _ + + // Schema of the actual Parquet files, without partition columns discovered from partition + // directory paths. + var parquetSchema: StructType = _ + + // Schema of the whole table, including partition columns. + var schema: StructType = _ + + // Indicates whether partition columns are also included in Parquet data file schema. If not, + // we need to fill in partition column values into read rows when scanning the table. + var partitionKeysIncludedInParquetSchema: Boolean = _ + + def prepareMetadata(path: Path, schema: StructType, conf: Configuration): Unit = { + conf.set( + ParquetOutputFormat.COMPRESSION, + ParquetRelation + .shortParquetCompressionCodecNames + .getOrElse( + sqlContext.conf.parquetCompressionCodec.toUpperCase, + CompressionCodecName.UNCOMPRESSED).name()) + + ParquetRelation.enableLogForwarding() + ParquetTypesConverter.writeMetaData(schema.toAttributes, path, conf) + } + + /** + * Refreshes `FileStatus`es, footers, partition spec, and table schema. + */ + def refresh(): Unit = { + // Support either reading a collection of raw Parquet part-files, or a collection of folders + // containing Parquet files (e.g. partitioned Parquet table). + val baseStatuses = paths.distinct.map { p => + val fs = FileSystem.get(URI.create(p), sparkContext.hadoopConfiguration) + val path = new Path(p) + val qualified = path.makeQualified(fs.getUri, fs.getWorkingDirectory) + + if (!fs.exists(qualified) && maybeSchema.isDefined) { + fs.mkdirs(qualified) + prepareMetadata(qualified, maybeSchema.get, sparkContext.hadoopConfiguration) + } + + fs.getFileStatus(qualified) + }.toArray + assert(baseStatuses.forall(!_.isDir) || baseStatuses.forall(_.isDir)) + + // Lists `FileStatus`es of all leaf nodes (files) under all base directories. + val leaves = baseStatuses.flatMap { f => + val fs = FileSystem.get(f.getPath.toUri, sparkContext.hadoopConfiguration) + SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => + isSummaryFile(f.getPath) || + !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + } } - val foundKeys = partitionPairs.map(_._1).distinct - if (foundKeys.size > 1) { - sys.error(s"Too many distinct partition keys: $foundKeys") + dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) + metadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) + commonMetadataStatuses = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) + + footers = (dataStatuses ++ metadataStatuses ++ commonMetadataStatuses).par.map { f => + val parquetMetadata = ParquetFileReader.readFooter( + sparkContext.hadoopConfiguration, f, ParquetMetadataConverter.NO_FILTER) + f -> new Footer(f.getPath, parquetMetadata) + }.seq.toMap + + partitionSpec = maybePartitionSpec.getOrElse { + val partitionDirs = leaves + .filterNot(baseStatuses.contains) + .map(_.getPath.getParent) + .distinct + + if (partitionDirs.nonEmpty) { + // Parses names and values of partition columns, and infer their data types. + ParquetRelation2.parsePartitions(partitionDirs, defaultPartitionName) + } else { + // No partition directories found, makes an empty specification + PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) + } } - // Do a parallel lookup of partition metadata. - val partitionFiles = - childDirs.par.map { d => - fs.listStatus(d.getPath) - // TODO: Is there a standard hadoop function for this? - .filterNot(_.getPath.getName.startsWith("_")) - .filterNot(_.getPath.getName.startsWith(".")) - }.seq - - partitionKeys = foundKeys.toSeq - partitions = partitionFiles.zip(partitionPairs).map { case (files, (key, value)) => - Partition(Map(key -> value.toInt), files) - }.toSeq - } else { - partitionKeys = Nil - partitions = Partition(Map.empty, childrenOfPath) :: Nil + // To get the schema. We first try to get the schema defined in maybeSchema. + // If maybeSchema is not defined, we will try to get the schema from existing parquet data + // (through readSchema). If data does not exist, we will try to get the schema defined in + // maybeMetastoreSchema (defined in the options of the data source). + // Finally, if we still could not get the schema. We throw an error. + parquetSchema = + maybeSchema + .orElse(readSchema()) + .orElse(maybeMetastoreSchema) + .getOrElse(sys.error("Failed to get the schema.")) + + partitionKeysIncludedInParquetSchema = + isPartitioned && + partitionColumns.forall(f => parquetSchema.fieldNames.contains(f.name)) + + schema = { + val fullRelationSchema = if (partitionKeysIncludedInParquetSchema) { + parquetSchema + } else { + StructType(parquetSchema.fields ++ partitionColumns.fields) + } + + // If this Parquet relation is converted from a Hive Metastore table, must reconcile case + // insensitivity issue and possible schema mismatch. + maybeMetastoreSchema + .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullRelationSchema)) + .getOrElse(fullRelationSchema) + } + } + + private def readSchema(): Option[StructType] = { + // Sees which file(s) we need to touch in order to figure out the schema. + val filesToTouch = + // Always tries the summary files first if users don't require a merged schema. In this case, + // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row + // groups information, and could be much smaller for large Parquet files with lots of row + // groups. + // + // NOTE: Metadata stored in the summary files are merged from all part-files. However, for + // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know + // how to merge them correctly if some key is associated with different values in different + // part-files. When this happens, Parquet simply gives up generating the summary file. This + // implies that if a summary file presents, then: + // + // 1. Either all part-files have exactly the same Spark SQL schema, or + // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus + // their schemas may differ from each other). + // + // Here we tend to be pessimistic and take the second case into account. Basically this means + // we can't trust the summary files if users require a merged schema, and must touch all part- + // files to do the merge. + if (shouldMergeSchemas) { + // Also includes summary files, 'cause there might be empty partition directories. + (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq + } else { + // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet + // don't have this. + commonMetadataStatuses.headOption + // Falls back to "_metadata" + .orElse(metadataStatuses.headOption) + // Summary file(s) not found, the Parquet file is either corrupted, or different part- + // files contain conflicting user defined metadata (two or more values are associated + // with a same key in different files). In either case, we fall back to any of the + // first part-file, and just assume all schemas are consistent. + .orElse(dataStatuses.headOption) + .toSeq + } + + ParquetRelation2.readSchema(filesToTouch.map(footers.apply), sqlContext) } } - override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum + @transient private val metadataCache = new MetadataCache + metadataCache.refresh() - val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes. - ParquetTypesConverter.readSchemaFromFile( - partitions.head.files.head.getPath, - Some(sparkContext.hadoopConfiguration), - sqlContext.conf.isParquetBinaryAsString)) + def partitionSpec: PartitionSpec = metadataCache.partitionSpec - val dataIncludesKey = - partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true) + def partitionColumns: StructType = metadataCache.partitionSpec.partitionColumns - override val schema = - if (dataIncludesKey) { - dataSchema - } else { - StructType(dataSchema.fields :+ StructField(partitionKeys.head, IntegerType)) - } + def partitions: Seq[Partition] = metadataCache.partitionSpec.partitions - override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { - // This is mostly a hack so that we can use the existing parquet filter code. - val requiredColumns = output.map(_.name) + def isPartitioned: Boolean = partitionColumns.nonEmpty - val job = new Job(sparkContext.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - val jobConf: Configuration = ContextUtil.getConfiguration(job) + private def partitionKeysIncludedInDataSchema = metadataCache.partitionKeysIncludedInParquetSchema - val requestedSchema = StructType(requiredColumns.map(schema(_))) + private def parquetSchema = metadataCache.parquetSchema - val partitionKeySet = partitionKeys.toSet - val rawPredicate = - predicates - .filter(_.references.map(_.name).toSet.subsetOf(partitionKeySet)) - .reduceOption(And) - .getOrElse(Literal(true)) + override def schema: StructType = metadataCache.schema - // Translate the predicate so that it reads from the information derived from the - // folder structure - val castedPredicate = rawPredicate transform { - case a: AttributeReference => - val idx = partitionKeys.indexWhere(a.name == _) - BoundReference(idx, IntegerType, nullable = true) - } + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } - val inputData = new GenericMutableRow(partitionKeys.size) - val pruningCondition = InterpretedPredicate(castedPredicate) + // TODO Should calculate per scan size + // It's common that a query only scans a fraction of a large Parquet file. Returning size of the + // whole Parquet file disables some optimizations in this case (e.g. broadcast join). + override val sizeInBytes = metadataCache.dataStatuses.map(_.getLen).sum - val selectedPartitions = - if (partitionKeys.nonEmpty && predicates.nonEmpty) { - partitions.filter { part => - inputData(0) = part.partitionValues.values.head - pruningCondition(inputData) - } - } else { - partitions + // This is mostly a hack so that we can use the existing parquet filter code. + override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { + val job = new Job(sparkContext.hadoopConfiguration) + ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) + val jobConf: Configuration = ContextUtil.getConfiguration(job) + + val selectedPartitions = prunePartitions(predicates, partitions) + val selectedFiles = if (isPartitioned) { + selectedPartitions.flatMap { p => + metadataCache.dataStatuses.filter(_.getPath.getParent.toString == p.path) } + } else { + metadataCache.dataStatuses.toSeq + } + val selectedFooters = selectedFiles.map(metadataCache.footers) - val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) - val selectedFiles = selectedPartitions.flatMap(_.files).map(f => fs.makeQualified(f.getPath)) // FileInputFormat cannot handle empty lists. if (selectedFiles.nonEmpty) { - org.apache.hadoop.mapreduce.lib.input.FileInputFormat.setInputPaths(job, selectedFiles: _*) + FileInputFormat.setInputPaths(job, selectedFiles.map(_.getPath): _*) + } + + // Try to push down filters when filter push-down is enabled. + if (sqlContext.conf.parquetFilterPushDown) { + val partitionColNames = partitionColumns.map(_.name).toSet + predicates + // Don't push down predicates which reference partition columns + .filter { pred => + val referencedColNames = pred.references.map(_.name).toSet + referencedColNames.intersect(partitionColNames).isEmpty + } + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter) + .reduceOption(FilterApi.and) + .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) } - // Push down filters when possible. Notice that not all filters can be converted to Parquet - // filter predicate. Here we try to convert each individual predicate and only collect those - // convertible ones. - predicates - .flatMap(ParquetFilters.createFilter) - .reduceOption(FilterApi.and) - .filter(_ => sqlContext.conf.parquetFilterPushDown) - .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + if (isPartitioned) { + logInfo { + val percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 + s"Reading $percentRead% of partitions" + } + } - def percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 - logInfo(s"Reading $percentRead% of $path partitions") + val requiredColumns = output.map(_.name) + val requestedSchema = StructType(requiredColumns.map(schema(_))) // Store both requested and original schema in `Configuration` jobConf.set( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(requestedSchema.toAttributes)) + convertToString(requestedSchema.toAttributes)) jobConf.set( RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(schema.toAttributes)) + convertToString(schema.toAttributes)) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata val useCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean jobConf.set(SQLConf.PARQUET_CACHE_METADATA, useCache.toString) val baseRDD = - new org.apache.spark.rdd.NewHadoopRDD( + new NewHadoopRDD( sparkContext, classOf[FilteringParquetRowInputFormat], classOf[Void], @@ -227,66 +484,506 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) val cacheMetadata = useCache @transient - val cachedStatus = selectedPartitions.flatMap(_.files) + val cachedStatus = selectedFiles + + @transient + val cachedFooters = selectedFooters // Overridden so we can inject our own cached files statuses. override def getPartitions: Array[SparkPartition] = { - val inputFormat = - if (cacheMetadata) { - new FilteringParquetRowInputFormat { - override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus - } - } else { - new FilteringParquetRowInputFormat - } + val inputFormat = if (cacheMetadata) { + new FilteringParquetRowInputFormat { + override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus - inputFormat match { - case configurable: Configurable => - configurable.setConf(getConf) - case _ => + override def getFooters(jobContext: JobContext): JList[Footer] = cachedFooters + } + } else { + new FilteringParquetRowInputFormat } + val jobContext = newJobContext(getConf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[SparkPartition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = - new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + val rawSplits = inputFormat.getSplits(jobContext) + + Array.tabulate[SparkPartition](rawSplits.size) { i => + new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) } - result } } - // The ordinal for the partition key in the result row, if requested. - val partitionKeyLocation = - partitionKeys - .headOption - .map(requiredColumns.indexOf(_)) - .getOrElse(-1) + // The ordinals for partition keys in the result row, if requested. + val partitionKeyLocations = partitionColumns.fieldNames.zipWithIndex.map { + case (name, index) => index -> requiredColumns.indexOf(name) + }.toMap.filter { + case (_, index) => index >= 0 + } // When the data does not include the key and the key is requested then we must fill it in // based on information from the input split. - if (!dataIncludesKey && partitionKeyLocation != -1) { - baseRDD.mapPartitionsWithInputSplit { case (split, iter) => - val partValue = "([^=]+)=([^=]+)".r - val partValues = - split.asInstanceOf[parquet.hadoop.ParquetInputSplit] - .getPath - .toString - .split("/") - .flatMap { - case partValue(key, value) => Some(key -> value) - case _ => None - }.toMap - - val currentValue = partValues.values.head.toInt - iter.map { pair => - val res = pair._2.asInstanceOf[SpecificMutableRow] - res.setInt(partitionKeyLocation, currentValue) - res + if (!partitionKeysIncludedInDataSchema && partitionKeyLocations.nonEmpty) { + // This check is based on CatalystConverter.createRootConverter. + val primitiveRow = + requestedSchema.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType)) + + baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => + val partValues = selectedPartitions.collectFirst { + case p if split.getPath.getParent.toString == p.path => p.values + }.get + + val requiredPartOrdinal = partitionKeyLocations.keys.toSeq + + if (primitiveRow) { + iterator.map { pair => + // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. + val row = pair._2.asInstanceOf[SpecificMutableRow] + var i = 0 + while (i < requiredPartOrdinal.size) { + // TODO Avoids boxing cost here! + val partOrdinal = requiredPartOrdinal(i) + row.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal)) + i += 1 + } + row + } + } else { + // Create a mutable row since we need to fill in values from partition columns. + val mutableRow = new GenericMutableRow(requestedSchema.size) + iterator.map { pair => + // We are using CatalystGroupConverter and it returns a GenericRow. + // Since GenericRow is not mutable, we just cast it to a Row. + val row = pair._2.asInstanceOf[Row] + var i = 0 + while (i < row.size) { + // TODO Avoids boxing cost here! + mutableRow(i) = row(i) + i += 1 + } + + i = 0 + while (i < requiredPartOrdinal.size) { + // TODO Avoids boxing cost here! + val partOrdinal = requiredPartOrdinal(i) + mutableRow.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal)) + i += 1 + } + mutableRow + } } } } else { baseRDD.map(_._2) } } + + private def prunePartitions( + predicates: Seq[Expression], + partitions: Seq[Partition]): Seq[Partition] = { + val partitionColumnNames = partitionColumns.map(_.name).toSet + val partitionPruningPredicates = predicates.filter { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + val rawPredicate = + partitionPruningPredicates.reduceOption(expressions.And).getOrElse(Literal(true)) + val boundPredicate = InterpretedPredicate(rawPredicate transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }) + + if (isPartitioned && partitionPruningPredicates.nonEmpty) { + partitions.filter(p => boundPredicate(p.values)) + } else { + partitions + } + } + + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + assert(paths.size == 1, s"Can't write to multiple destinations: ${paths.mkString(",")}") + + // TODO: currently we do not check whether the "schema"s are compatible + // That means if one first creates a table and then INSERTs data with + // and incompatible schema the execution will fail. It would be nice + // to catch this early one, maybe having the planner validate the schema + // before calling execute(). + + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val writeSupport = if (parquetSchema.map(_.dataType).forall(_.isPrimitive)) { + log.debug("Initializing MutableRowWriteSupport") + classOf[MutableRowWriteSupport] + } else { + classOf[RowWriteSupport] + } + + ParquetOutputFormat.setWriteSupportClass(job, writeSupport) + + val conf = ContextUtil.getConfiguration(job) + RowWriteSupport.setSchema(data.schema.toAttributes, conf) + + val destinationPath = new Path(paths.head) + + if (overwrite) { + val fs = destinationPath.getFileSystem(conf) + if (fs.exists(destinationPath)) { + var success: Boolean = false + try { + success = fs.delete(destinationPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${destinationPath.toString} prior" + + s" to writing to Parquet table:\n${e.toString}") + } + if (!success) { + throw new IOException( + s"Unable to clear output directory ${destinationPath.toString} prior" + + s" to writing to Parquet table.") + } + } + } + + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[Row]) + FileOutputFormat.setOutputPath(job, destinationPath) + + val wrappedConf = new SerializableWritable(job.getConfiguration) + val jobTrackerId = new SimpleDateFormat("yyyyMMddHHmm").format(new Date()) + val stageId = sqlContext.sparkContext.newRddId() + + val taskIdOffset = if (overwrite) { + 1 + } else { + FileSystemHelper.findMaxTaskId( + FileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 + } + + def writeShard(context: TaskContext, iterator: Iterator[Row]): Unit = { + /* "reduce task" */ + val attemptId = newTaskAttemptID( + jobTrackerId, stageId, isMap = false, context.partitionId(), context.attemptNumber()) + val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val format = new AppendingParquetOutputFormat(taskIdOffset) + val committer = format.getOutputCommitter(hadoopContext) + committer.setupTask(hadoopContext) + val writer = format.getRecordWriter(hadoopContext) + try { + while (iterator.hasNext) { + val row = iterator.next() + writer.write(null, row) + } + } finally { + writer.close(hadoopContext) + } + + SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) + } + val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) + /* apparently we need a TaskAttemptID to construct an OutputCommitter; + * however we're only going to use this local OutputCommitter for + * setupJob/commitJob, so we just use a dummy "map" task. + */ + val jobAttemptId = newTaskAttemptID(jobTrackerId, stageId, isMap = true, 0, 0) + val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + + jobCommitter.setupJob(jobTaskContext) + sqlContext.sparkContext.runJob(data.queryExecution.executedPlan.execute(), writeShard _) + jobCommitter.commitJob(jobTaskContext) + + metadataCache.refresh() + } +} + +private[sql] object ParquetRelation2 extends Logging { + // Whether we should merge schemas collected from all Parquet part-files. + val MERGE_SCHEMA = "mergeSchema" + + // Default partition name to use when the partition column value is null or empty string. + val DEFAULT_PARTITION_NAME = "partition.defaultName" + + // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used + // internally. + private[sql] val METASTORE_SCHEMA = "metastoreSchema" + + private[parquet] def readSchema( + footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = { + footers.map { footer => + val metadata = footer.getParquetMetadata.getFileMetaData + val parquetSchema = metadata.getSchema + val maybeSparkSchema = metadata + .getKeyValueMetaData + .toMap + .get(RowReadSupport.SPARK_METADATA_KEY) + .flatMap { serializedSchema => + // Don't throw even if we failed to parse the serialized Spark schema. Just fallback to + // whatever is available. + Try(DataType.fromJson(serializedSchema)) + .recover { case _: Throwable => + logInfo( + s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + DataType.fromCaseClassString(serializedSchema) + } + .recover { case cause: Throwable => + logWarning( + s"""Failed to parse serialized Spark schema in Parquet key-value metadata: + |\t$serializedSchema + """.stripMargin, + cause) + } + .map(_.asInstanceOf[StructType]) + .toOption + } + + maybeSparkSchema.getOrElse { + // Falls back to Parquet schema if Spark SQL schema is absent. + StructType.fromAttributes( + // TODO Really no need to use `Attribute` here, we only need to know the data type. + convertToAttributes( + parquetSchema, + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.isParquetINT96AsTimestamp)) + } + }.reduceOption { (left, right) => + try left.merge(right) catch { case e: Throwable => + throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e) + } + } + } + + /** + * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore + * schema and Parquet schema. + * + * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the + * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't + * distinguish binary and string). This method generates a correct schema by merging Metastore + * schema data types and Parquet schema field names. + */ + private[parquet] def mergeMetastoreParquetSchema( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + def schemaConflictMessage: String = + s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: + |${metastoreSchema.prettyJson} + | + |Parquet schema: + |${parquetSchema.prettyJson} + """.stripMargin + + val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) + + assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) + + val ordinalMap = metastoreSchema.zipWithIndex.map { + case (field, index) => field.name.toLowerCase -> index + }.toMap + val reorderedParquetSchema = mergedParquetSchema.sortBy(f => + ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) + + StructType(metastoreSchema.zip(reorderedParquetSchema).map { + // Uses Parquet field names but retains Metastore data types. + case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => + mSchema.copy(name = pSchema.name) + case _ => + throw new SparkException(schemaConflictMessage) + }) + } + + /** + * Returns the original schema from the Parquet file with any missing nullable fields from the + * Hive Metastore schema merged in. + * + * When constructing a DataFrame from a collection of structured data, the resulting object has + * a schema corresponding to the union of the fields present in each element of the collection. + * Spark SQL simply assigns a null value to any field that isn't present for a particular row. + * In some cases, it is possible that a given table partition stored as a Parquet file doesn't + * contain a particular nullable field in its schema despite that field being present in the + * table schema obtained from the Hive Metastore. This method returns a schema representing the + * Parquet file schema along with any additional nullable fields from the Metastore schema + * merged in. + */ + private[parquet] def mergeMissingNullableFields( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap + val missingFields = metastoreSchema + .map(_.name.toLowerCase) + .diff(parquetSchema.map(_.name.toLowerCase)) + .map(fieldMap(_)) + .filter(_.nullable) + StructType(parquetSchema ++ missingFields) + } + + + // TODO Data source implementations shouldn't touch Catalyst types (`Literal`). + // However, we are already using Catalyst expressions for partition pruning and predicate + // push-down here... + private[parquet] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { + require(columnNames.size == literals.size) + } + + /** + * Given a group of qualified paths, tries to parse them and returns a partition specification. + * For example, given: + * {{{ + * hdfs://:/path/to/partition/a=1/b=hello/c=3.14 + * hdfs://:/path/to/partition/a=2/b=world/c=6.28 + * }}} + * it returns: + * {{{ + * PartitionSpec( + * partitionColumns = StructType( + * StructField(name = "a", dataType = IntegerType, nullable = true), + * StructField(name = "b", dataType = StringType, nullable = true), + * StructField(name = "c", dataType = DoubleType, nullable = true)), + * partitions = Seq( + * Partition( + * values = Row(1, "hello", 3.14), + * path = "hdfs://:/path/to/partition/a=1/b=hello/c=3.14"), + * Partition( + * values = Row(2, "world", 6.28), + * path = "hdfs://:/path/to/partition/a=2/b=world/c=6.28"))) + * }}} + */ + private[parquet] def parsePartitions( + paths: Seq[Path], + defaultPartitionName: String): PartitionSpec = { + val partitionValues = resolvePartitions(paths.map(parsePartition(_, defaultPartitionName))) + val fields = { + val (PartitionValues(columnNames, literals)) = partitionValues.head + columnNames.zip(literals).map { case (name, Literal(_, dataType)) => + StructField(name, dataType, nullable = true) + } + } + + val partitions = partitionValues.zip(paths).map { + case (PartitionValues(_, literals), path) => + Partition(Row(literals.map(_.value): _*), path.toString) + } + + PartitionSpec(StructType(fields), partitions) + } + + /** + * Parses a single partition, returns column names and values of each partition column. For + * example, given: + * {{{ + * path = hdfs://:/path/to/partition/a=42/b=hello/c=3.14 + * }}} + * it returns: + * {{{ + * PartitionValues( + * Seq("a", "b", "c"), + * Seq( + * Literal(42, IntegerType), + * Literal("hello", StringType), + * Literal(3.14, FloatType))) + * }}} + */ + private[parquet] def parsePartition( + path: Path, + defaultPartitionName: String): PartitionValues = { + val columns = ArrayBuffer.empty[(String, Literal)] + // Old Hadoop versions don't have `Path.isRoot` + var finished = path.getParent == null + var chopped = path + + while (!finished) { + val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName) + maybeColumn.foreach(columns += _) + chopped = chopped.getParent + finished = maybeColumn.isEmpty || chopped.getParent == null + } + + val (columnNames, values) = columns.reverse.unzip + PartitionValues(columnNames, values) + } + + private def parsePartitionColumn( + columnSpec: String, + defaultPartitionName: String): Option[(String, Literal)] = { + val equalSignIndex = columnSpec.indexOf('=') + if (equalSignIndex == -1) { + None + } else { + val columnName = columnSpec.take(equalSignIndex) + assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'") + + val rawColumnValue = columnSpec.drop(equalSignIndex + 1) + assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") + + val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName) + Some(columnName -> literal) + } + } + + /** + * Resolves possible type conflicts between partitions by up-casting "lower" types. The up- + * casting order is: + * {{{ + * NullType -> + * IntegerType -> LongType -> + * FloatType -> DoubleType -> DecimalType.Unlimited -> + * StringType + * }}} + */ + private[parquet] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { + // Column names of all partitions must match + val distinctPartitionsColNames = values.map(_.columnNames).distinct + assert(distinctPartitionsColNames.size == 1, { + val list = distinctPartitionsColNames.mkString("\t", "\n", "") + s"Conflicting partition column names detected:\n$list" + }) + + // Resolves possible type conflicts for each column + val columnCount = values.head.columnNames.size + val resolvedValues = (0 until columnCount).map { i => + resolveTypeConflicts(values.map(_.literals(i))) + } + + // Fills resolved literals back to each partition + values.zipWithIndex.map { case (d, index) => + d.copy(literals = resolvedValues.map(_(index))) + } + } + + /** + * Converts a string to a `Literal` with automatic type inference. Currently only supports + * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * [[StringType]]. + */ + private[parquet] def inferPartitionColumnValue( + raw: String, + defaultPartitionName: String): Literal = { + // First tries integral types + Try(Literal(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal(JLong.parseLong(raw), LongType))) + // Then falls back to fractional types + .orElse(Try(Literal(JFloat.parseFloat(raw), FloatType))) + .orElse(Try(Literal(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal(new JBigDecimal(raw), DecimalType.Unlimited))) + // Then falls back to string + .getOrElse { + if (raw == defaultPartitionName) Literal(null, NullType) else Literal(raw, StringType) + } + } + + private val upCastingOrder: Seq[DataType] = + Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType) + + /** + * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" + * types. + */ + private def resolveTypeConflicts(literals: Seq[Literal]): Seq[Literal] = { + val desiredType = { + val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_)) + // Falls back to string if all values of this column are null or empty string + if (topType == NullType) StringType else topType + } + + literals.map { case l @ Literal(_, dataType) => + Literal(Cast(l, desiredType).eval(), desiredType) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala new file mode 100644 index 0000000000000..70bcca7526aae --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet.timestamp + +import java.nio.{ByteBuffer, ByteOrder} + +import parquet.Preconditions +import parquet.io.api.{Binary, RecordConsumer} + +private[parquet] class NanoTime extends Serializable { + private var julianDay = 0 + private var timeOfDayNanos = 0L + + def set(julianDay: Int, timeOfDayNanos: Long): this.type = { + this.julianDay = julianDay + this.timeOfDayNanos = timeOfDayNanos + this + } + + def getJulianDay: Int = julianDay + + def getTimeOfDayNanos: Long = timeOfDayNanos + + def toBinary: Binary = { + val buf = ByteBuffer.allocate(12) + buf.order(ByteOrder.LITTLE_ENDIAN) + buf.putLong(timeOfDayNanos) + buf.putInt(julianDay) + buf.flip() + Binary.fromByteBuffer(buf) + } + + def writeValue(recordConsumer: RecordConsumer): Unit = { + recordConsumer.addBinary(toBinary) + } + + override def toString: String = + "NanoTime{julianDay=" + julianDay + ", timeOfDayNanos=" + timeOfDayNanos + "}" +} + +private[sql] object NanoTime { + def fromBinary(bytes: Binary): NanoTime = { + Preconditions.checkArgument(bytes.length() == 12, "Must be 12 bytes") + val buf = bytes.toByteBuffer + buf.order(ByteOrder.LITTLE_ENDIAN) + val timeOfDayNanos = buf.getLong + val julianDay = buf.getInt + new NanoTime().set(julianDay, timeOfDayNanos) + } + + def apply(julianDay: Int, timeOfDayNanos: Long): NanoTime = { + new NanoTime().set(julianDay, timeOfDayNanos) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 386ff2452f1a3..e13759b7feb7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, Strategy} import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, InsertIntoTable => LogicalInsertIntoTable} -import org.apache.spark.sql.execution +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.{Row, Strategy, execution, sources} /** * A Strategy for planning scans over data sources defined using the sources API. @@ -54,12 +55,9 @@ private[sql] object DataSourceStrategy extends Strategy { case l @ LogicalRelation(t: TableScan) => execution.PhysicalRDD(l.output, t.buildScan()) :: Nil - case i @ LogicalInsertIntoTable( - l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) => - if (partition.nonEmpty) { - sys.error(s"Insert into a partition is not allowed because $l is not partitioned.") - } - execution.ExecutedCommand(InsertIntoRelation(t, query, overwrite)) :: Nil + case i @ logical.InsertIntoTable( + l @ LogicalRelation(t: InsertableRelation), part, query, overwrite) if part.isEmpty => + execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil case _ => Nil } @@ -88,7 +86,7 @@ private[sql] object DataSourceStrategy extends Strategy { val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = filterPredicates.reduceLeftOption(And) + val filterCondition = filterPredicates.reduceLeftOption(expressions.And) val pushedFilters = filterPredicates.map { _ transform { case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. @@ -118,27 +116,69 @@ private[sql] object DataSourceStrategy extends Strategy { } } - /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. */ - protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { - case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v) - case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v) - - case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v) - case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v) - - case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v) - case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v) - - case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) => - GreaterThanOrEqual(a.name, v) - case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) => - LessThanOrEqual(a.name, v) - - case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) => - LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) => - GreaterThanOrEqual(a.name, v) + /** + * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, + * and convert them. + */ + protected[sql] def selectFilters(filters: Seq[Expression]) = { + def translate(predicate: Expression): Option[Filter] = predicate match { + case expressions.EqualTo(a: Attribute, Literal(v, _)) => + Some(sources.EqualTo(a.name, v)) + case expressions.EqualTo(Literal(v, _), a: Attribute) => + Some(sources.EqualTo(a.name, v)) + + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => + Some(sources.GreaterThan(a.name, v)) + case expressions.GreaterThan(Literal(v, _), a: Attribute) => + Some(sources.LessThan(a.name, v)) + + case expressions.LessThan(a: Attribute, Literal(v, _)) => + Some(sources.LessThan(a.name, v)) + case expressions.LessThan(Literal(v, _), a: Attribute) => + Some(sources.GreaterThan(a.name, v)) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + Some(sources.GreaterThanOrEqual(a.name, v)) + case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + Some(sources.LessThanOrEqual(a.name, v)) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => + Some(sources.LessThanOrEqual(a.name, v)) + case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => + Some(sources.GreaterThanOrEqual(a.name, v)) + + case expressions.InSet(a: Attribute, set) => + Some(sources.In(a.name, set.toArray)) + + case expressions.IsNull(a: Attribute) => + Some(sources.IsNull(a.name)) + case expressions.IsNotNull(a: Attribute) => + Some(sources.IsNotNull(a.name)) + + case expressions.And(left, right) => + (translate(left) ++ translate(right)).reduceOption(sources.And) + + case expressions.Or(left, right) => + for { + leftFilter <- translate(left) + rightFilter <- translate(right) + } yield sources.Or(leftFilter, rightFilter) + + case expressions.Not(child) => + translate(child).map(sources.Not) + + case expressions.StartsWith(a: Attribute, Literal(v: String, StringType)) => + Some(sources.StringStartsWith(a.name, v)) + + case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) => + Some(sources.StringEndsWith(a.name, v)) + + case expressions.Contains(a: Attribute, Literal(v: String, StringType)) => + Some(sources.StringContains(a.name, v)) + + case _ => None + } - case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray) + filters.flatMap(translate) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala index 12b59ba20bb10..f374abffdd505 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -30,24 +30,28 @@ private[sql] case class LogicalRelation(relation: BaseRelation) override val output: Seq[AttributeReference] = relation.schema.toAttributes // Logical Relations are distinct if they have different output for the sake of transformations. - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case l @ LogicalRelation(otherRelation) => relation == otherRelation && output == l.output case _ => false } - override def sameResult(otherPlan: LogicalPlan) = otherPlan match { + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(relation, output) + } + + override def sameResult(otherPlan: LogicalPlan): Boolean = otherPlan match { case LogicalRelation(otherRelation) => relation == otherRelation case _ => false } - @transient override lazy val statistics = Statistics( + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = BigInt(relation.sizeInBytes) ) /** Used to lookup original attribute capitalization */ - val attributeMap = AttributeMap(output.map(o => (o, o))) + val attributeMap: AttributeMap[AttributeReference] = AttributeMap(output.map(o => (o, o))) - def newInstance() = LogicalRelation(relation).asInstanceOf[this.type] + def newInstance(): this.type = LogicalRelation(relation).asInstanceOf[this.type] - override def simpleString = s"Relation[${output.mkString(",")}] $relation" + override def simpleString: String = s"Relation[${output.mkString(",")}] $relation" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index d7942dc30934b..dbdb0d39c26a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -21,14 +21,22 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand -private[sql] case class InsertIntoRelation( - relation: InsertableRelation, +private[sql] case class InsertIntoDataSource( + logicalRelation: LogicalRelation, query: LogicalPlan, overwrite: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { - relation.insert(DataFrame(sqlContext, query), overwrite) + override def run(sqlContext: SQLContext): Seq[Row] = { + val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] + val data = DataFrame(sqlContext, query) + // Apply the schema of the existing table to the new data. + val df = sqlContext.createDataFrame( + data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false) + relation.insert(df, overwrite) + + // Invalidate the cache. + sqlContext.cacheManager.invalidateCache(logicalRelation) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index ead827728cf4b..a53abb992b443 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.sources import scala.language.implicitConversions import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{AnalysisException, SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -30,7 +32,9 @@ import org.apache.spark.util.Utils /** * A parser for foreign DDL commands. */ -private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { +private[sql] class DDLParser( + parseQuery: String => LogicalPlan) + extends AbstractSparkSQLParser with DataTypeParser with Logging { def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { try { @@ -42,15 +46,6 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { } } - def parseType(input: String): DataType = { - lexical.initialize(reservedWords) - phrase(dataType)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => throw new DDLException(s"Unsupported dataType: $x") - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object protected val CREATE = Keyword("CREATE") @@ -61,28 +56,13 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { protected val EXISTS = Keyword("EXISTS") protected val USING = Keyword("USING") protected val OPTIONS = Keyword("OPTIONS") + protected val DESCRIBE = Keyword("DESCRIBE") + protected val EXTENDED = Keyword("EXTENDED") protected val AS = Keyword("AS") protected val COMMENT = Keyword("COMMENT") + protected val REFRESH = Keyword("REFRESH") - // Data types. - protected val STRING = Keyword("STRING") - protected val BINARY = Keyword("BINARY") - protected val BOOLEAN = Keyword("BOOLEAN") - protected val TINYINT = Keyword("TINYINT") - protected val SMALLINT = Keyword("SMALLINT") - protected val INT = Keyword("INT") - protected val BIGINT = Keyword("BIGINT") - protected val FLOAT = Keyword("FLOAT") - protected val DOUBLE = Keyword("DOUBLE") - protected val DECIMAL = Keyword("DECIMAL") - protected val DATE = Keyword("DATE") - protected val TIMESTAMP = Keyword("TIMESTAMP") - protected val VARCHAR = Keyword("VARCHAR") - protected val ARRAY = Keyword("ARRAY") - protected val MAP = Keyword("MAP") - protected val STRUCT = Keyword("STRUCT") - - protected lazy val ddl: Parser[LogicalPlan] = createTable + protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable protected def start: Parser[LogicalPlan] = ddl @@ -101,26 +81,37 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { * AS SELECT ... */ protected lazy val createTable: Parser[LogicalPlan] = - ( - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident - ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ~ (AS ~> restInput).? ^^ { + // TODO: Support database.table. + (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ + tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => if (temp.isDefined && allowExisting.isDefined) { throw new DDLException( "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") } + val options = opts.getOrElse(Map.empty[String, String]) if (query.isDefined) { if (columns.isDefined) { throw new DDLException( "a CREATE TABLE AS SELECT statement does not allow column definitions.") } + // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. + val mode = if (allowExisting.isDefined) { + SaveMode.Ignore + } else if (temp.isDefined) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + + val queryPlan = parseQuery(query.get) CreateTableUsingAsSelect(tableName, provider, temp.isDefined, - opts, - allowExisting.isDefined, - query.get) + mode, + options, + queryPlan) } else { val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) CreateTableUsing( @@ -128,14 +119,36 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { userSpecifiedSchema, provider, temp.isDefined, - opts, - allowExisting.isDefined) + options, + allowExisting.isDefined, + managedIfNoPath = false) } - } - ) + } protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + /* + * describe [extended] table avroTable + * This will display all columns of table `avroTable` includes column_name,column_type,nullable + */ + protected lazy val describeTable: Parser[LogicalPlan] = + (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { + case e ~ db ~ tbl => + val tblIdentifier = db match { + case Some(dbName) => + Seq(dbName, tbl) + case None => + Seq(tbl) + } + DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) + } + + protected lazy val refreshTable: Parser[LogicalPlan] = + REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { + case maybeDatabaseName ~ tableName => + RefreshTable(maybeDatabaseName.getOrElse("default"), tableName) + } + protected lazy val options: Parser[Map[String, String]] = "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } @@ -147,191 +160,188 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => val meta = cm match { case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase(), comment).build() + new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() case None => Metadata.empty } - StructField(columnName, typ, true, meta) - } - protected lazy val primitiveType: Parser[DataType] = - STRING ^^^ StringType | - BINARY ^^^ BinaryType | - BOOLEAN ^^^ BooleanType | - TINYINT ^^^ ByteType | - SMALLINT ^^^ ShortType | - INT ^^^ IntegerType | - BIGINT ^^^ LongType | - FLOAT ^^^ FloatType | - DOUBLE ^^^ DoubleType | - fixedDecimalType | // decimal with precision/scale - DECIMAL ^^^ DecimalType.Unlimited | // decimal with no precision/scale - DATE ^^^ DateType | - TIMESTAMP ^^^ TimestampType | - VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType - - protected lazy val fixedDecimalType: Parser[DataType] = - (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + StructField(columnName, typ, nullable = true, meta) } +} - protected lazy val arrayType: Parser[DataType] = - ARRAY ~> "<" ~> dataType <~ ">" ^^ { - case tpe => ArrayType(tpe) - } +private[sql] object ResolvedDataSource { - protected lazy val mapType: Parser[DataType] = - MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { - case t1 ~ _ ~ t2 => MapType(t1, t2) - } + private val builtinSources = Map( + "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource], + "json" -> classOf[org.apache.spark.sql.json.DefaultSource], + "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource] + ) - protected lazy val structField: Parser[StructField] = - ident ~ ":" ~ dataType ^^ { - case fieldName ~ _ ~ tpe => StructField(fieldName, tpe, nullable = true) + /** Given a provider name, look up the data source class definition. */ + def lookupDataSource(provider: String): Class[_] = { + if (builtinSources.contains(provider)) { + return builtinSources(provider) } - protected lazy val structType: Parser[DataType] = - (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ { - case fields => StructType(fields) - }) | - (STRUCT ~> "<>" ^^ { - case fields => StructType(Nil) - }) - - private[sql] lazy val dataType: Parser[DataType] = - arrayType | - mapType | - structType | - primitiveType -} - -object ResolvedDataSource { - def apply( - sqlContext: SQLContext, - userSpecifiedSchema: Option[StructType], - provider: String, - options: Map[String, String]): ResolvedDataSource = { val loader = Utils.getContextOrSparkClassLoader - val clazz: Class[_] = try loader.loadClass(provider) catch { + try { + loader.loadClass(provider) + } catch { case cnf: java.lang.ClassNotFoundException => - try loader.loadClass(provider + ".DefaultSource") catch { + try { + loader.loadClass(provider + ".DefaultSource") + } catch { case cnf: java.lang.ClassNotFoundException => sys.error(s"Failed to load class for data source: $provider") } } + } + /** Create a [[ResolvedDataSource]] for reading data in. */ + def apply( + sqlContext: SQLContext, + userSpecifiedSchema: Option[StructType], + provider: String, + options: Map[String, String]): ResolvedDataSource = { + val clazz: Class[_] = lookupDataSource(provider) + def className = clazz.getCanonicalName val relation = userSpecifiedSchema match { - case Some(schema: StructType) => { - clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider] - .createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case dataSource: org.apache.spark.sql.sources.RelationProvider => - sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.") - } + case Some(schema: StructType) => clazz.newInstance() match { + case dataSource: SchemaRelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) + case dataSource: org.apache.spark.sql.sources.RelationProvider => + throw new AnalysisException(s"$className does not allow user-specified schemas.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") } - case None => { - clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.RelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.RelationProvider] - .createRelation(sqlContext, new CaseInsensitiveMap(options)) - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.") - } + + case None => clazz.newInstance() match { + case dataSource: RelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) + case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => + throw new AnalysisException( + s"A schema needs to be specified when using $className.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") } } - new ResolvedDataSource(clazz, relation) } + /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */ def apply( sqlContext: SQLContext, provider: String, + mode: SaveMode, options: Map[String, String], data: DataFrame): ResolvedDataSource = { - val loader = Utils.getContextOrSparkClassLoader - val clazz: Class[_] = try loader.loadClass(provider) catch { - case cnf: java.lang.ClassNotFoundException => - try loader.loadClass(provider + ".DefaultSource") catch { - case cnf: java.lang.ClassNotFoundException => - sys.error(s"Failed to load class for data source: $provider") - } - } - - val relation = clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.CreateableRelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.CreateableRelationProvider] - .createRelation(sqlContext, options, data) + val clazz: Class[_] = lookupDataSource(provider) + val relation = clazz.newInstance() match { + case dataSource: CreatableRelationProvider => + dataSource.createRelation(sqlContext, mode, options, data) case _ => sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") } - new ResolvedDataSource(clazz, relation) } } private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) +/** + * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. + * @param table The table to be described. + * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. + * It is effective only when the table is a Hive table. + */ +private[sql] case class DescribeCommand( + table: LogicalPlan, + isExtended: Boolean) extends Command { + override val output = Seq( + // Column names are based on Hive. + AttributeReference("col_name", StringType, nullable = false, + new MetadataBuilder().putString("comment", "name of the column").build())(), + AttributeReference("data_type", StringType, nullable = false, + new MetadataBuilder().putString("comment", "data type of the column").build())(), + AttributeReference("comment", StringType, nullable = false, + new MetadataBuilder().putString("comment", "comment of the column").build())()) +} + +/** + * Used to represent the operation of create table using a data source. + * @param allowExisting If it is true, we will do nothing when the table already exists. + * If it is false, an exception will be thrown + */ private[sql] case class CreateTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, temporary: Boolean, options: Map[String, String], - allowExisting: Boolean) extends Command - -private[sql] case class CreateTableUsingAsSelect( - tableName: String, - provider: String, - temporary: Boolean, - options: Map[String, String], allowExisting: Boolean, - query: String) extends Command + managedIfNoPath: Boolean) extends Command -private[sql] case class CreateTableUsingAsLogicalPlan( +/** + * A node used to support CTAS statements and saveAsTable for the data source API. + * This node is a [[UnaryNode]] instead of a [[Command]] because we want the analyzer + * can analyze the logical plan that will be used to populate the table. + * So, [[PreWriteCheck]] can detect cases that are not allowed. + */ +private[sql] case class CreateTableUsingAsSelect( tableName: String, provider: String, temporary: Boolean, + mode: SaveMode, options: Map[String, String], - allowExisting: Boolean, - query: LogicalPlan) extends Command + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = Seq.empty[Attribute] + // TODO: Override resolved after we support databaseName. + // override lazy val resolved = databaseName != None && childrenResolved +} -private [sql] case class CreateTempTableUsing( +private[sql] case class CreateTempTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String]) extends RunnableCommand { - def run(sqlContext: SQLContext) = { + def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) - sqlContext.registerRDDAsTable( + sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) Seq.empty } } -private [sql] case class CreateTempTableUsingAsSelect( +private[sql] case class CreateTempTableUsingAsSelect( tableName: String, provider: String, + mode: SaveMode, options: Map[String, String], query: LogicalPlan) extends RunnableCommand { - def run(sqlContext: SQLContext) = { + def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) - val resolved = ResolvedDataSource(sqlContext, provider, options, df) - sqlContext.registerRDDAsTable( + val resolved = ResolvedDataSource(sqlContext, provider, mode, options, df) + sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) Seq.empty } } +private[sql] case class RefreshTable(databaseName: String, tableName: String) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.catalog.refreshTable(databaseName, tableName) + Seq.empty[Row] + } +} + /** * Builds a map in which keys are case insensitive */ -protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] +protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] with Serializable { val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) @@ -343,11 +353,10 @@ protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, override def iterator: Iterator[(String, String)] = baseMap.iterator - override def -(key: String): Map[String, String] = baseMap - key.toLowerCase() + override def -(key: String): Map[String, String] = baseMap - key.toLowerCase } /** * The exception thrown from the DDL parser. - * @param message */ protected[sql] class DDLException(message: String) extends Exception(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 4a9fefc12b9ad..791046e0079d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,11 +17,85 @@ package org.apache.spark.sql.sources +/** + * A filter predicate for data sources. + */ abstract class Filter +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * equal to `value`. + */ case class EqualTo(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * greater than `value`. + */ case class GreaterThan(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * greater than or equal to `value`. + */ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * less than `value`. + */ case class LessThan(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * less than or equal to `value`. + */ case class LessThanOrEqual(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to one of the values in the array. + */ case class In(attribute: String, values: Array[Any]) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to null. + */ +case class IsNull(attribute: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a non-null value. + */ +case class IsNotNull(attribute: String) extends Filter + +/** + * A filter that evaluates to `true` iff both `left` or `right` evaluate to `true`. + */ +case class And(left: Filter, right: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff at least one of `left` or `right` evaluates to `true`. + */ +case class Or(left: Filter, right: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff `child` is evaluated to `false`. + */ +case class Not(child: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that starts with `value`. + */ +case class StringStartsWith(attribute: String, value: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that starts with `value`. + */ +case class StringEndsWith(attribute: String, value: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that contains the string `value`. + */ +case class StringContains(attribute: String, value: String) extends Filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index ad0a35b91ebc2..8f9946a5a801e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -14,11 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.sources import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{SaveMode, DataFrame, Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} import org.apache.spark.sql.types.StructType @@ -78,22 +79,35 @@ trait SchemaRelationProvider { } @DeveloperApi -trait CreateableRelationProvider { +trait CreatableRelationProvider { + /** + * Creates a relation with the given parameters based on the contents of the given + * DataFrame. The mode specifies the expected behavior of createRelation when + * data already exists. + * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. + * Append mode means that when saving a DataFrame to a data source, if data already exists, + * contents of the DataFrame are expected to be appended to existing data. + * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, + * existing data is expected to be overwritten by the contents of the DataFrame. + * ErrorIfExists mode means that when saving a DataFrame to a data source, + * if data already exists, an exception is expected to be thrown. + */ def createRelation( sqlContext: SQLContext, + mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation } /** * ::DeveloperApi:: - * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must - * be able to produce the schema of their data in the form of a [[StructType]] Concrete + * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must + * be able to produce the schema of their data in the form of a [[StructType]]. Concrete * implementation should inherit from one of the descendant `Scan` classes, which define various * abstract methods for execution. * * BaseRelations must also define a equality function that only returns true when the two - * instances will return the same data. This equality function is used when determining when + * instances will return the same data. This equality function is used when determining when * it is safe to substitute cached results for a given relation. */ @DeveloperApi @@ -102,13 +116,16 @@ abstract class BaseRelation { def schema: StructType /** - * Returns an estimated size of this relation in bytes. This information is used by the planner + * Returns an estimated size of this relation in bytes. This information is used by the planner * to decided when it is safe to broadcast a relation and can be overridden by sources that * know the size ahead of time. By default, the system will assume that tables are too - * large to broadcast. This method will be called multiple times during query planning + * large to broadcast. This method will be called multiple times during query planning * and thus should not perform expensive operations for each invocation. + * + * Note that it is always better to overestimate size than underestimate, because underestimation + * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). */ - def sizeInBytes = sqlContext.conf.defaultSizeInBytes + def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes } /** @@ -116,7 +133,7 @@ abstract class BaseRelation { * A BaseRelation that can produce all of its tuples as an RDD of Row objects. */ @DeveloperApi -trait TableScan extends BaseRelation { +trait TableScan { def buildScan(): RDD[Row] } @@ -126,7 +143,7 @@ trait TableScan extends BaseRelation { * containing all of its tuples as Row objects. */ @DeveloperApi -trait PrunedScan extends BaseRelation { +trait PrunedScan { def buildScan(requiredColumns: Array[String]): RDD[Row] } @@ -135,29 +152,48 @@ trait PrunedScan extends BaseRelation { * A BaseRelation that can eliminate unneeded columns and filter using selected * predicates before producing an RDD containing all matching tuples as Row objects. * + * The actual filter should be the conjunction of all `filters`, + * i.e. they should be "and" together. + * * The pushed down filters are currently purely an optimization as they will all be evaluated * again. This means it is safe to use them with methods that produce false positives such * as filtering partitions based on a bloom filter. */ @DeveloperApi -trait PrunedFilteredScan extends BaseRelation { +trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } +/** + * ::DeveloperApi:: + * A BaseRelation that can be used to insert data into it through the insert method. + * If overwrite in insert method is true, the old data in the relation should be overwritten with + * the new data. If overwrite in insert method is false, the new data should be appended. + * + * InsertableRelation has the following three assumptions. + * 1. It assumes that the data (Rows in the DataFrame) provided to the insert method + * exactly matches the ordinal of fields in the schema of the BaseRelation. + * 2. It assumes that the schema of this relation will not be changed. + * Even if the insert method updates the schema (e.g. a relation of JSON or Parquet data may have a + * schema update after an insert operation), the new schema will not be used. + * 3. It assumes that fields of the data provided in the insert method are nullable. + * If a data source needs to check the actual nullability of a field, it needs to do it in the + * insert method. + */ +@DeveloperApi +trait InsertableRelation { + def insert(data: DataFrame, overwrite: Boolean): Unit +} + /** * ::Experimental:: * An interface for experimenting with a more direct connection to the query planner. Compared to * [[PrunedFilteredScan]], this operator receives the raw expressions from the * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. Unlike the other APIs this - * interface is not designed to be binary compatible across releases and thus should only be used + * interface is NOT designed to be binary compatible across releases and thus should only be used * for experimentation. */ @Experimental -trait CatalystScan extends BaseRelation { +trait CatalystScan { def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } - -@DeveloperApi -trait InsertableRelation extends BaseRelation { - def insert(data: DataFrame, overwrite: Boolean): Unit -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala new file mode 100644 index 0000000000000..5a78001117d1b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.{SaveMode, AnalysisException} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, Catalog} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Alias} +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.DataType + +/** + * A rule to do pre-insert data type casting and field renaming. Before we insert into + * an [[InsertableRelation]], we will use this rule to make sure that + * the columns to be inserted have the correct data type and fields have the correct names. + */ +private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Wait until children are resolved. + case p: LogicalPlan if !p.childrenResolved => p + + // We are inserting into an InsertableRelation. + case i @ InsertIntoTable( + l @ LogicalRelation(r: InsertableRelation), partition, child, overwrite) => { + // First, make sure the data to be inserted have the same number of fields with the + // schema of the relation. + if (l.output.size != child.output.size) { + sys.error( + s"$l requires that the query in the SELECT clause of the INSERT INTO/OVERWRITE " + + s"statement generates the same number of columns as its schema.") + } + castAndRenameChildOutput(i, l.output, child) + } + } + + /** If necessary, cast data types and rename fields to the expected types and names. */ + def castAndRenameChildOutput( + insertInto: InsertIntoTable, + expectedOutput: Seq[Attribute], + child: LogicalPlan): InsertIntoTable = { + val newChildOutput = expectedOutput.zip(child.output).map { + case (expected, actual) => + val needCast = !expected.dataType.sameType(actual.dataType) + // We want to make sure the filed names in the data to be inserted exactly match + // names in the schema. + val needRename = expected.name != actual.name + (needCast, needRename) match { + case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)() + case (false, true) => Alias(actual, expected.name)() + case (_, _) => actual + } + } + + if (newChildOutput == child.output) { + insertInto + } else { + insertInto.copy(child = Project(newChildOutput, child)) + } + } +} + +/** + * A rule to do various checks before inserting into or writing to a data source table. + */ +private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) { + def failAnalysis(msg: String): Unit = { throw new AnalysisException(msg) } + + def apply(plan: LogicalPlan): Unit = { + plan.foreach { + case i @ logical.InsertIntoTable( + l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) => + // Right now, we do not support insert into a data source table with partition specs. + if (partition.nonEmpty) { + failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.") + } else { + // Get all input data source relations of the query. + val srcRelations = query.collect { + case LogicalRelation(src: BaseRelation) => src + } + if (srcRelations.contains(t)) { + failAnalysis( + "Cannot insert overwrite into table that is also being read from.") + } else { + // OK + } + } + + case i @ logical.InsertIntoTable( + l: LogicalRelation, partition, query, overwrite) if !l.isInstanceOf[InsertableRelation] => + // The relation in l is not an InsertableRelation. + failAnalysis(s"$l does not allow insertion.") + + case CreateTableUsingAsSelect(tableName, _, _, SaveMode.Overwrite, _, query) => + // When the SaveMode is Overwrite, we need to check if the table is an input table of + // the query. If so, we will throw an AnalysisException to let users know it is not allowed. + if (catalog.tableExists(Seq(tableName))) { + // Need to remove SubQuery operator. + EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) match { + // Only do the check if the table is a data source table + // (the relation is a BaseRelation). + case l @ LogicalRelation(dest: BaseRelation) => + // Get all input data source relations of the query. + val srcRelations = query.collect { + case LogicalRelation(src: BaseRelation) => src + } + if (srcRelations.contains(dest)) { + failAnalysis( + s"Cannot overwrite table $tableName that is also being read from.") + } else { + // OK + } + + case _ => // OK + } + } else { + // OK + } + + case _ => // OK + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 006b16fbe07bd..2fdd798b44bb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ * @param y y coordinate */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) -private[sql] class ExamplePoint(val x: Double, val y: Double) +private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable /** * User-defined type for [[ExamplePoint]]. @@ -37,7 +37,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def sqlType: DataType = ArrayType(DoubleType, false) - override def pyUDT: String = "pyspark.tests.ExamplePointUDT" + override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" override def serialize(obj: Any): Seq[Double] = { obj match { @@ -59,4 +59,6 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { } override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] + + private[spark] override def asNullable: ExamplePointUDT = this } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java deleted file mode 100644 index 639436368c4a3..0000000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import com.google.common.collect.ImmutableMap; - -import org.apache.spark.sql.Column; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.types.DataTypes; - -import static org.apache.spark.sql.Dsl.*; - -/** - * This test doesn't actually run anything. It is here to check the API compatibility for Java. - */ -public class JavaDsl { - - public static void testDataFrame(final DataFrame df) { - DataFrame df1 = df.select("colA"); - df1 = df.select("colA", "colB"); - - df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1)); - - df1 = df.filter(col("colA")); - - java.util.Map aggExprs = ImmutableMap.builder() - .put("colA", "sum") - .put("colB", "avg") - .build(); - - df1 = df.agg(aggExprs); - - df1 = df.groupBy("groupCol").agg(aggExprs); - - df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer"); - - df.orderBy("colA"); - df.orderBy("colA", "colB", "colC"); - df.orderBy(col("colA").desc()); - df.orderBy(col("colA").desc(), col("colB").asc()); - - df.sort("colA"); - df.sort("colA", "colB", "colC"); - df.sort(col("colA").desc()); - df.sort(col("colA").desc(), col("colB").asc()); - - df.as("b"); - - df.limit(5); - - df.unionAll(df1); - df.intersect(df1); - df.except(df1); - - df.sample(true, 0.1, 234); - - df.head(); - df.head(5); - df.first(); - df.count(); - } - - public static void testColumn(final Column c) { - c.asc(); - c.desc(); - - c.endsWith("abcd"); - c.startsWith("afgasdf"); - - c.like("asdf%"); - c.rlike("wef%asdf"); - - c.as("newcol"); - - c.cast("int"); - c.cast(DataTypes.IntegerType); - } - - public static void testDsl() { - // Creating a column. - Column c = col("abcd"); - Column c1 = column("abcd"); - - // Literals - Column l1 = lit(1); - Column l2 = lit(1.0); - Column l3 = lit("abcd"); - - // Functions - Column a = upper(c); - a = lower(c); - a = sqrt(c); - a = abs(c); - - // Aggregates - a = min(c); - a = max(c); - a = sum(c); - a = sumDistinct(c); - a = countDistinct(c, a); - a = avg(c); - a = first(c); - a = last(c); - } -} diff --git a/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java b/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java deleted file mode 100644 index 80bd74f5b5525..0000000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc; - -import org.junit.*; -import static org.junit.Assert.*; -import java.sql.Connection; -import java.sql.DriverManager; - -import org.apache.spark.SparkEnv; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.api.java.*; -import org.apache.spark.sql.test.TestSQLContext$; - -public class JavaJDBCTest { - static String url = "jdbc:h2:mem:testdb1"; - - static Connection conn = null; - - // This variable will always be null if TestSQLContext is intact when running - // these tests. Some Java tests do not play nicely with others, however; - // they create a SparkContext of their own at startup and stop it at exit. - // This renders TestSQLContext inoperable, meaning we have to do the same - // thing. If this variable is nonnull, that means we allocated a - // SparkContext of our own and that we need to stop it at teardown. - static JavaSparkContext localSparkContext = null; - - static SQLContext sql = TestSQLContext$.MODULE$; - - @Before - public void beforeTest() throws Exception { - if (SparkEnv.get() == null) { // A previous test destroyed TestSQLContext. - localSparkContext = new JavaSparkContext("local", "JavaAPISuite"); - sql = new SQLContext(localSparkContext); - } - Class.forName("org.h2.Driver"); - conn = DriverManager.getConnection(url); - conn.prepareStatement("create schema test").executeUpdate(); - conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate(); - conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate(); - conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate(); - conn.prepareStatement("insert into test.people values ('joe', 3)").executeUpdate(); - conn.commit(); - } - - @After - public void afterTest() throws Exception { - if (localSparkContext != null) { - localSparkContext.stop(); - localSparkContext = null; - } - try { - conn.close(); - } finally { - conn = null; - } - } - - @Test - public void basicTest() throws Exception { - DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE"); - Row[] rows = rdd.collect(); - assertEquals(rows.length, 3); - } - - @Test - public void partitioningTest() throws Exception { - String[] parts = new String[2]; - parts[0] = "THEID < 2"; - parts[1] = "THEID = 2"; // Deliberately forget about one of them. - DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE", parts); - Row[] rows = rdd.collect(); - assertEquals(rows.length, 2); - } - - @Test - public void writeTest() throws Exception { - DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE"); - JDBCUtils.createJDBCTable(rdd, url, "TEST.PEOPLECOPY", false); - DataFrame rdd2 = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLECOPY"); - Row[] rows = rdd2.collect(); - assertEquals(rows.length, 3); - } -} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java similarity index 72% rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index badd00d34b9b1..c344a9b095c52 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -15,13 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java; +package test.org.apache.spark.sql; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.apache.spark.sql.test.TestSQLContext$; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -38,19 +39,18 @@ // see http://stackoverflow.com/questions/758570/. public class JavaApplySchemaSuite implements Serializable { private transient JavaSparkContext javaCtx; - private transient SQLContext javaSqlCtx; + private transient SQLContext sqlContext; @Before public void setUp() { - javaCtx = new JavaSparkContext("local", "JavaApplySchemaSuite"); - javaSqlCtx = new SQLContext(javaCtx); + sqlContext = TestSQLContext$.MODULE$; + javaCtx = new JavaSparkContext(sqlContext.sparkContext()); } @After public void tearDown() { - javaCtx.stop(); javaCtx = null; - javaSqlCtx = null; + sqlContext = null; } public static class Person implements Serializable { @@ -98,9 +98,9 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = javaSqlCtx.applySchema(rowRDD.rdd(), schema); + DataFrame df = sqlContext.applySchema(rowRDD, schema); df.registerTempTable("people"); - Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect(); + Row[] actual = sqlContext.sql("SELECT * FROM people").collect(); List expected = new ArrayList(2); expected.add(RowFactory.create("Michael", 29)); @@ -109,6 +109,46 @@ public Row call(Person person) throws Exception { Assert.assertEquals(expected, Arrays.asList(actual)); } + @Test + public void dataFrameRDDOperations() { + List personList = new ArrayList(2); + Person person1 = new Person(); + person1.setName("Michael"); + person1.setAge(29); + personList.add(person1); + Person person2 = new Person(); + person2.setName("Yin"); + person2.setAge(28); + personList.add(person2); + + JavaRDD rowRDD = javaCtx.parallelize(personList).map( + new Function() { + public Row call(Person person) throws Exception { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); + StructType schema = DataTypes.createStructType(fields); + + DataFrame df = sqlContext.applySchema(rowRDD, schema); + df.registerTempTable("people"); + List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { + + public String call(Row row) { + return row.getString(0) + "_" + row.get(1).toString(); + } + }).collect(); + + List expected = new ArrayList(2); + expected.add("Michael_29"); + expected.add("Yin_28"); + + Assert.assertEquals(expected, actual); + } + @Test public void applySchemaToJSON() { JavaRDD jsonRDD = javaCtx.parallelize(Arrays.asList( @@ -122,7 +162,7 @@ public void applySchemaToJSON() { fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); - fields.add(DataTypes.createStructField("integer", DataTypes.IntegerType, true)); + fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true)); fields.add(DataTypes.createStructField("long", DataTypes.LongType, true)); fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); @@ -147,18 +187,18 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD.rdd()); + DataFrame df1 = sqlContext.jsonRDD(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.registerTempTable("jsonTable1"); - List actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList(); + List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema); + DataFrame df2 = sqlContext.jsonRDD(jsonRDD, expectedSchema); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.registerTempTable("jsonTable2"); - List actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList(); + List actual2 = sqlContext.sql("select * from jsonTable2").collectAsList(); Assert.assertEquals(expectedResult, actual2); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java new file mode 100644 index 0000000000000..2d586f784ac5a --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +import org.apache.spark.sql.*; +import org.apache.spark.sql.test.TestSQLContext$; +import static org.apache.spark.sql.functions.*; + + +public class JavaDataFrameSuite { + private transient SQLContext context; + + @Before + public void setUp() { + // Trigger static initializer of TestData + TestData$.MODULE$.testData(); + context = TestSQLContext$.MODULE$; + } + + @After + public void tearDown() { + context = null; + } + + @Test + public void testExecution() { + DataFrame df = context.table("testData").filter("key = 1"); + Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + } + + /** + * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java. + */ + @Test + public void testVarargMethods() { + DataFrame df = context.table("testData"); + + df.toDF("key1", "value1"); + + df.select("key", "value"); + df.select(col("key"), col("value")); + df.selectExpr("key", "value + 1"); + + df.sort("key", "value"); + df.sort(col("key"), col("value")); + df.orderBy("key", "value"); + df.orderBy(col("key"), col("value")); + + df.groupBy("key", "value").agg(col("key"), col("value"), sum("value")); + df.groupBy(col("key"), col("value")).agg(col("key"), col("value"), sum("value")); + df.agg(first("key"), sum("value")); + + df.groupBy().avg("key"); + df.groupBy().mean("key"); + df.groupBy().max("key"); + df.groupBy().min("key"); + df.groupBy().sum("key"); + + // Varargs in column expressions + df.groupBy().agg(countDistinct("key", "value")); + df.groupBy().agg(countDistinct(col("key"), col("value"))); + df.select(coalesce(col("key"))); + } + + @Ignore + public void testShow() { + // This test case is intended ignored, but to make sure it compiles correctly + DataFrame df = context.table("testData"); + df.show(); + df.show(1000); + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java similarity index 99% rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index fbfcd3f59d910..4ce1d1dddb26a 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java; +package test.org.apache.spark.sql; import java.math.BigDecimal; import java.sql.Date; diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java similarity index 89% rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index e5588938ea162..79d92734ff375 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java; +package test.org.apache.spark.sql; import java.io.Serializable; @@ -23,28 +23,29 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.api.java.UDF1; +import org.apache.spark.sql.api.java.UDF2; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. -public class JavaAPISuite implements Serializable { +public class JavaUDFSuite implements Serializable { private transient JavaSparkContext sc; private transient SQLContext sqlContext; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaAPISuite"); - sqlContext = new SQLContext(sc); + sqlContext = TestSQLContext$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; } @SuppressWarnings("unchecked") diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java new file mode 100644 index 0000000000000..b76f7d421f643 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources; + +import java.io.File; +import java.io.IOException; +import java.util.*; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.test.TestSQLContext$; +import org.apache.spark.sql.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; + +public class JavaSaveLoadSuite { + + private transient JavaSparkContext sc; + private transient SQLContext sqlContext; + + String originalDefaultSource; + File path; + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + sqlContext = TestSQLContext$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); + + originalDefaultSource = sqlContext.conf().defaultDataSourceName(); + path = + Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); + if (path.exists()) { + path.delete(); + } + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); + } + JavaRDD rdd = sc.parallelize(jsonObjects); + df = sqlContext.jsonRDD(rdd); + df.registerTempTable("jsonTable"); + } + + @Test + public void saveAndLoad() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + + DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", options); + + checkAnswer(loadedDF, df.collectAsList()); + } + + @Test + public void saveAndLoadWithSchema() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + + List fields = new ArrayList(); + fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", schema, options); + + checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); + } +} diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index fbed0a782dd3e..28e90b9520b2c 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -39,6 +39,9 @@ log4j.appender.FA.Threshold = INFO log4j.additivity.parquet.hadoop.ParquetRecordReader=false log4j.logger.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.parquet.hadoop.ParquetOutputCommitter=OFF + log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c9221f8f934ad..c240f2be955ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql +import scala.concurrent.duration._ +import scala.language.{implicitConversions, postfixOps} + +import org.scalatest.concurrent.Eventually._ + import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.storage.{StorageLevel, RDDBlockId} +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.storage.{RDDBlockId, StorageLevel} case class BigData(s: String) @@ -52,15 +57,15 @@ class CachedTableSuite extends QueryTest { test("unpersist an uncached table will not raise exception") { assert(None == cacheManager.lookupCachedData(testData)) - testData.unpersist(true) + testData.unpersist(blocking = true) assert(None == cacheManager.lookupCachedData(testData)) - testData.unpersist(false) + testData.unpersist(blocking = false) assert(None == cacheManager.lookupCachedData(testData)) testData.persist() assert(None != cacheManager.lookupCachedData(testData)) - testData.unpersist(true) + testData.unpersist(blocking = true) assert(None == cacheManager.lookupCachedData(testData)) - testData.unpersist(false) + testData.unpersist(blocking = false) assert(None == cacheManager.lookupCachedData(testData)) } @@ -87,7 +92,7 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF().registerTempTable("bigData") table("bigData").persist(StorageLevel.MEMORY_AND_DISK) assert(table("bigData").count() === 200000L) table("bigData").unpersist(blocking = true) @@ -191,7 +196,10 @@ class CachedTableSuite extends QueryTest { sql("UNCACHE TABLE testData") assert(!isCached("testData"), "Table 'testData' should not be cached") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { @@ -204,7 +212,9 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") uncacheTable("testCacheTable") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("CACHE TABLE tableName AS SELECT ...") { @@ -217,7 +227,9 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") uncacheTable("testCacheTable") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("CACHE LAZY TABLE tableName") { @@ -235,7 +247,9 @@ class CachedTableSuite extends QueryTest { "Lazily cached in-memory table should have been materialized") uncacheTable("testData") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("InMemoryRelation statistics") { @@ -266,4 +280,20 @@ class CachedTableSuite extends QueryTest { assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) assert(!isCached("t2")) } + + test("Clear all cache") { + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + cacheTable("t1") + cacheTable("t2") + clearCache() + assert(cacheManager.isEmpty) + + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + cacheTable("t1") + cacheTable("t2") + sql("Clear CACHE") + assert(cacheManager.isEmpty) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index fa4cdecbcb340..3c05f9c8c9d5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType} @@ -27,43 +30,10 @@ class ColumnExpressionSuite extends QueryTest { // TODO: Add test cases for bitwise operations. - test("computability check") { - def shouldBeComputable(c: Column): Unit = assert(c.isComputable === true) - - def shouldNotBeComputable(c: Column): Unit = { - assert(c.isComputable === false) - intercept[UnsupportedOperationException] { c.head() } - } - - shouldBeComputable(testData2("a")) - shouldBeComputable(testData2("b")) - - shouldBeComputable(testData2("a") + testData2("b")) - shouldBeComputable(testData2("a") + testData2("b") + 1) - - shouldBeComputable(-testData2("a")) - shouldBeComputable(!testData2("a")) - - shouldBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b")) - shouldBeComputable( - testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d")) - shouldBeComputable( - testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b")) - - // Literals and unresolved columns should not be computable. - shouldNotBeComputable(col("1")) - shouldNotBeComputable(col("1") + 2) - shouldNotBeComputable(lit(100)) - shouldNotBeComputable(lit(100) + 10) - shouldNotBeComputable(-col("1")) - shouldNotBeComputable(!col("1")) - - // Getting data from different frames should not be computable. - shouldNotBeComputable(testData2("a") + testData("key")) - shouldNotBeComputable(testData2("a") + 1 + testData("key")) - - // Aggregate functions alone should not be computable. - shouldNotBeComputable(sum(testData2("a"))) + test("collect on column produced by a binary operator") { + val df = Seq((1, 2, 3)).toDF("a", "b", "c") + checkAnswer(df.select(df("a") + df("b")), Seq(Row(3))) + checkAnswer(df.select(df("a") + df("b").as("c")), Seq(Row(3))) } test("star") { @@ -71,8 +41,7 @@ class ColumnExpressionSuite extends QueryTest { } test("star qualified by data frame object") { - // This is not yet supported. - val df = testData.toDataFrame + val df = testData.toDF val goldAnswer = df.collect().toSeq checkAnswer(df.select(df("*")), goldAnswer) @@ -149,13 +118,13 @@ class ColumnExpressionSuite extends QueryTest { test("isNull") { checkAnswer( - nullStrings.toDataFrame.where($"s".isNull), + nullStrings.toDF.where($"s".isNull), nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) } test("isNotNull") { checkAnswer( - nullStrings.toDataFrame.where($"s".isNotNull), + nullStrings.toDF.where($"s".isNotNull), nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) } @@ -180,7 +149,7 @@ class ColumnExpressionSuite extends QueryTest { } test("!==") { - val nullData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( + val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -240,7 +209,7 @@ class ColumnExpressionSuite extends QueryTest { testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1))) } - val booleanData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( + val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -342,4 +311,15 @@ class ColumnExpressionSuite extends QueryTest { (1 to 100).map(n => Row(null)) ) } + + test("lift alias out of cast") { + compareExpressions( + col("1234").as("name").cast("int").expr, + col("1234").cast("int").as("name").expr) + } + + test("columns can be compared") { + assert('key.desc == 'key.desc) + assert('key.desc != 'key.asc) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala new file mode 100644 index 0000000000000..2d2367d6e7292 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -0,0 +1,55 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc} +import org.apache.spark.sql.test.TestSQLContext.implicits._ + + +class DataFrameImplicitsSuite extends QueryTest { + + test("RDD of tuples") { + checkAnswer( + sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + (1 to 10).map(i => Row(i, i.toString))) + } + + test("Seq of tuples") { + checkAnswer( + (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + (1 to 10).map(i => Row(i, i.toString))) + } + + test("RDD[Int]") { + checkAnswer( + sc.parallelize(1 to 10).toDF("intCol"), + (1 to 10).map(i => Row(i))) + } + + test("RDD[Long]") { + checkAnswer( + sc.parallelize(1L to 10L).toDF("longCol"), + (1L to 10L).map(i => Row(i))) + } + + test("RDD[String]") { + checkAnswer( + sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + (1 to 10).map(i => Row(i.toString))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala new file mode 100644 index 0000000000000..0896f175c056f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConversions._ + +import org.apache.spark.sql.test.TestSQLContext.implicits._ + + +class DataFrameNaFunctionsSuite extends QueryTest { + + def createDF(): DataFrame = { + Seq[(String, java.lang.Integer, java.lang.Double)]( + ("Bob", 16, 176.5), + ("Alice", null, 164.3), + ("David", 60, null), + ("Amy", null, null), + (null, null, null)).toDF("name", "age", "height") + } + + test("drop") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("name" :: Nil), + rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + + checkAnswer( + input.na.drop("age" :: Nil), + rows(0) :: rows(2) :: Nil) + + checkAnswer( + input.na.drop("age" :: "height" :: Nil), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(), + rows(0)) + + // dropna on an a dataframe with no column should return an empty data frame. + val empty = input.sqlContext.emptyDataFrame.select() + assert(empty.na.drop().count() === 0L) + + // Make sure the columns are properly named. + assert(input.na.drop().columns.toSeq === input.columns.toSeq) + } + + test("drop with how") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("all"), + rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + + checkAnswer( + input.na.drop("any"), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("any", Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("all", Seq("age", "height")), + rows(0) :: rows(1) :: rows(2) :: Nil) + } + + test("drop with threshold") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop(2, Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(3, Seq("name", "age", "height")), + rows(0)) + + // Make sure the columns are properly named. + assert(input.na.drop(2, Seq("age", "height")).columns.toSeq === input.columns.toSeq) + } + + test("fill") { + val input = createDF() + + val fillNumeric = input.na.fill(50.6) + checkAnswer( + fillNumeric, + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, 50.6) :: + Row("Amy", 50, 50.6) :: + Row(null, 50, 50.6) :: Nil) + + // Make sure the columns are properly named. + assert(fillNumeric.columns.toSeq === input.columns.toSeq) + + // string + checkAnswer( + input.na.fill("unknown").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Amy") :: Row("unknown") :: Nil) + assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) + + // fill double with subset columns + checkAnswer( + input.na.fill(50.6, "age" :: Nil), + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, null) :: + Row("Amy", 50, null) :: + Row(null, 50, null) :: Nil) + + // fill string with subset columns + checkAnswer( + Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), + Row("test", null)) + } + + test("fill with map") { + val df = Seq[(String, String, java.lang.Long, java.lang.Double)]( + (null, null, null, null)).toDF("a", "b", "c", "d") + checkAnswer( + df.na.fill(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + )), + Row("test", null, 1, 2.2)) + + // Test Java version + checkAnswer( + df.na.fill(mapAsJavaMap(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + ))), + Row("test", null, 1, 2.2)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f6b65a81ce05e..1db0cf7daac03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -17,18 +17,24 @@ package org.apache.spark.sql -import org.apache.spark.sql.Dsl._ +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} +import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.test.TestSQLContext.sql -/* Implicits */ -import org.apache.spark.sql.test.TestSQLContext._ - -import scala.language.postfixOps class DataFrameSuite extends QueryTest { import org.apache.spark.sql.TestData._ test("analysis error should be eagerly reported") { + val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis + // Eager analysis. + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + intercept[Exception] { testData.select('nonExistentName) } intercept[Exception] { testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) @@ -39,6 +45,45 @@ class DataFrameSuite extends QueryTest { intercept[Exception] { testData.groupBy($"abcd").agg(Map("key" -> "sum")) } + + // No more eager analysis once the flag is turned off + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + testData.select('nonExistentName) + + // Set the flag back to original value before this test. + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + } + + test("dataframe toString") { + assert(testData.toString === "[key: int, value: string]") + assert(testData("key").toString === "key") + assert($"test".toString === "test") + } + + test("rename nested groupby") { + val df = Seq((1,(1,1))).toDF() + + checkAnswer( + df.groupBy("_1").agg(col("_1"), sum("_2._1")).toDF("key", "total"), + Row(1, 1) :: Nil) + } + + test("invalid plan toString, debug mode") { + val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + + // Turn on debug mode so we can see invalid query plans. + import org.apache.spark.sql.execution.debug._ + TestSQLContext.debug() + + val badPlan = testData.select('badColumn) + + assert(badPlan.toString contains badPlan.queryExecution.toString, + "toString on bad query plans should include the query execution but was:\n" + + badPlan.toString) + + // Set the flag back to original value before this test. + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("table scan") { @@ -47,21 +92,122 @@ class DataFrameSuite extends QueryTest { testData.collect().toSeq) } + test("empty data frame") { + assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(TestSQLContext.emptyDataFrame.count() === 0) + } + + test("head and take") { + assert(testData.take(2) === testData.collect().take(2)) + assert(testData.head(2) === testData.collect().take(2)) + assert(testData.head(2).head.schema === testData.schema) + } + + test("self join") { + val df1 = testData.select(testData("key")).as('df1) + val df2 = testData.select(testData("key")).as('df2) + + checkAnswer( + df1.join(df2, $"df1.key" === $"df2.key"), + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) + } + + test("simple explode") { + val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words") + + checkAnswer( + df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word), + Row("a") :: Row("b") :: Row("c") :: Row("d") ::Row("e") :: Nil + ) + } + + test("self join with aliases") { + val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + + test("explode") { + val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") + val df2 = + df.explode('letters) { + case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq + } + + checkAnswer( + df2 + .select('_1 as 'letter, 'number) + .groupBy('letter) + .agg('letter, countDistinct('number)), + Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil + ) + } + + test("selectExpr") { + checkAnswer( + testData.selectExpr("abs(key)", "value"), + testData.collect().map(row => Row(math.abs(row.getInt(0)), row.getString(1))).toSeq) + } + + test("selectExpr with alias") { + checkAnswer( + testData.selectExpr("key as k").select("k"), + testData.select("key").collect().toSeq) + } + + test("filterExpr") { + checkAnswer( + testData.filter("key > 90"), + testData.collect().filter(_.getInt(0) > 90).toSeq) + } + test("repartition") { checkAnswer( testData.select('key).repartition(10).select('key), testData.select('key).collect().toSeq) } - test("agg") { + test("groupBy") { checkAnswer( testData2.groupBy("a").agg($"a", sum($"b")), - Seq(Row(1,3), Row(2,3), Row(3,3)) + Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) checkAnswer( testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)), Row(9) ) + checkAnswer( + testData2.groupBy("a").agg(col("a"), count("*")), + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil + ) + checkAnswer( + testData2.groupBy("a").agg(Map("*" -> "count")), + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil + ) + checkAnswer( + testData2.groupBy("a").agg(Map("b" -> "sum")), + Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil + ) + + val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d")) + .toDF("key", "value1", "value2", "rest") + + checkAnswer( + df1.groupBy("key").min(), + df1.groupBy("key").min("value1", "value2").collect() + ) + checkAnswer( + df1.groupBy("key").min("value2"), + Seq(Row("a", 0), Row("b", 4)) + ) + } + + test("agg without groups") { checkAnswer( testData2.agg(sum('b)), Row(9) @@ -117,6 +263,10 @@ class DataFrameSuite extends QueryTest { testData2.orderBy('a.asc, 'b.asc), Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + checkAnswer( + testData2.orderBy(asc("a"), desc("b")), + Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + checkAnswer( testData2.orderBy('a.asc, 'b.desc), Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) @@ -130,20 +280,20 @@ class DataFrameSuite extends QueryTest { Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) checkAnswer( - arrayData.orderBy('data.getItem(0).asc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) + arrayData.toDF().orderBy('data.getItem(0).asc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( - arrayData.orderBy('data.getItem(0).desc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) + arrayData.toDF().orderBy('data.getItem(0).desc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( - arrayData.orderBy('data.getItem(1).asc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) + arrayData.toDF().orderBy('data.getItem(1).asc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( - arrayData.orderBy('data.getItem(1).desc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) + arrayData.toDF().orderBy('data.getItem(1).desc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } test("limit") { @@ -152,11 +302,11 @@ class DataFrameSuite extends QueryTest { testData.take(10).toSeq) checkAnswer( - arrayData.limit(1), + arrayData.toDF().limit(1), arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) checkAnswer( - mapData.limit(1), + mapData.toDF().limit(1), mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) } @@ -280,18 +430,95 @@ class DataFrameSuite extends QueryTest { } test("udf") { - val foo = (a: Int, b: String) => a.toString + b + val foo = udf((a: Int, b: String) => a.toString + b) checkAnswer( // SELECT *, foo(key, value) FROM testData - testData.select($"*", callUDF(foo, 'key, 'value)).limit(3), + testData.select($"*", foo('key, 'value)).limit(3), Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil ) } + test("withColumn") { + val df = testData.toDF().withColumn("newCol", col("key") + 1) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1) + }.toSeq) + assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol")) + } + + test("withColumnRenamed") { + val df = testData.toDF().withColumn("newCol", col("key") + 1) + .withColumnRenamed("value", "valueRenamed") + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1) + }.toSeq) + assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol")) + } + + test("describe") { + val describeTestData = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + + val describeResult = Seq( + Row("count", 4, 4), + Row("mean", 33.0, 178.0), + Row("stddev", 16.583123951777, 10.0), + Row("min", 16, 164), + Row("max", 60, 192)) + + val emptyDescribeResult = Seq( + Row("count", 0, 0), + Row("mean", null, null), + Row("stddev", null, null), + Row("min", null, null), + Row("max", null, null)) + + def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) + + val describeTwoCols = describeTestData.describe("age", "height") + assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) + checkAnswer(describeTwoCols, describeResult) + + val describeAllCols = describeTestData.describe() + assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height")) + checkAnswer(describeAllCols, describeResult) + + val describeOneCol = describeTestData.describe("age") + assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) + checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} ) + + val describeNoCol = describeTestData.select("name").describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} ) + + val emptyDescription = describeTestData.limit(0).describe() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height")) + checkAnswer(emptyDescription, emptyDescribeResult) + } + test("apply on query results (SPARK-5462)") { val df = testData.sqlContext.sql("select key from testData") - checkAnswer(df("key"), testData.select('key).collect().toSeq) + checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) + } + + ignore("show") { + // This test case is intended ignored, but to make sure it compiles correctly + testData.select($"*").show() + testData.select($"*").show(1000) } + test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { + val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) + val df = TestSQLContext.createDataFrame(rowRDD, schema) + df.rdd.collect() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index f0c939dbb195f..e4dee87849fd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ class JoinSuite extends QueryTest with BeforeAndAfterEach { @@ -33,14 +34,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") - val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed + val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan val planned = planner.HashJoin(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val rdd = sql(sqlString) - val physical = rdd.queryExecution.sparkPlan + val df = sql(sqlString) + val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j case j: HashOuterJoin => j @@ -108,7 +109,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") - val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed + val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan val planned = planner.HashJoin(join) assert(planned.size === 1) } @@ -409,8 +410,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("left semi join") { - val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") - checkAnswer(rdd, + val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + checkAnswer(df, Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala new file mode 100644 index 0000000000000..f9f41eb358bd5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -0,0 +1,87 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} + +class ListTablesSuite extends QueryTest with BeforeAndAfter { + + import org.apache.spark.sql.test.TestSQLContext.implicits._ + + val df = + sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + + before { + df.registerTempTable("ListTablesSuiteTable") + } + + after { + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + } + + test("get all tables") { + checkAnswer( + tables().filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + checkAnswer( + sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + } + + test("getting all Tables with a database name has no impact on returned table names") { + checkAnswer( + tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + checkAnswer( + sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + } + + test("query the returned DataFrame of tables") { + val expectedSchema = StructType( + StructField("tableName", StringType, false) :: + StructField("isTemporary", BooleanType, false) :: Nil) + + Seq(tables(), sql("SHOW TABLes")).foreach { + case tableDF => + assert(expectedSchema === tableDF.schema) + + tableDF.registerTempTable("tables") + checkAnswer( + sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), + Row(true, "ListTablesSuiteTable") + ) + checkAnswer( + tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + Row("tables", true)) + dropTempTable("tables") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index a7f2faa3ecf75..9b4dd6c620fec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,38 +17,89 @@ package org.apache.spark.sql +import java.util.{Locale, TimeZone} + +import scala.collection.JavaConversions._ + import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation class QueryTest extends PlanTest { + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer - * @param rdd the [[DataFrame]] to be executed + * @param df the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output * @param keywords keyword in string array */ - def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { - val outputs = rdd.collect().map(_.mkString).mkString + def checkExistence(df: DataFrame, exists: Boolean, keywords: String*) { + val outputs = df.collect().map(_.mkString).mkString for (key <- keywords) { if (exists) { - assert(outputs.contains(key), s"Failed for $rdd ($key doens't exist in result)") + assert(outputs.contains(key), s"Failed for $df ($key doesn't exist in result)") } else { - assert(!outputs.contains(key), s"Failed for $rdd ($key existed in the result)") + assert(!outputs.contains(key), s"Failed for $df ($key existed in the result)") } } } /** * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { - val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { + QueryTest.checkAnswer(df, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } + + /** + * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + */ + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } +} + +object QueryTest { + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a [[None]] will + * be returned. + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { + val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to @@ -59,61 +110,47 @@ class QueryTest extends PlanTest { case o => o }) } - if (!isSorted) converted.sortBy(_.toString) else converted + if (!isSorted) converted.sortBy(_.toString()) else converted } - val sparkAnswer = try rdd.collect().toSeq catch { + val sparkAnswer = try df.collect().toSeq catch { case e: Exception => - fail( + val errorMessage = s""" |Exception thrown while executing query: - |${rdd.queryExecution} + |${df.queryExecution} |== Exception == |$e |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin) + """.stripMargin + return Some(errorMessage) } if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - fail(s""" + val errorMessage = + s""" |Results do not match for query: - |${rdd.logicalPlan} + |${df.logicalPlan} |== Analyzed Plan == - |${rdd.queryExecution.analyzed} + |${df.queryExecution.analyzed} |== Physical Plan == - |${rdd.queryExecution.executedPlan} + |${df.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin) + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) } - } - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) - } - - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } + return None } - /** - * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. - */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { + checkAnswer(df, expectedAnswer.toSeq) match { + case Some(errorMessage) => errorMessage + case None => null } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index f5b945f468dad..36465cc2fa11a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.SparkSqlSerializer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class RowSuite extends FunSuite { @@ -50,4 +53,13 @@ class RowSuite extends FunSuite { row(0) = null assert(row.isNullAt(0)) } + + test("serialize w/ kryo") { + val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() + val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) + val instance = serializer.newInstance() + val ser = instance.serialize(row) + val de = instance.deserialize(ser).asInstanceOf[Row] + assert(de === row) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index e18ba287e8683..87e7cf8c8af9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,32 +17,54 @@ package org.apache.spark.sql -import java.util.TimeZone - +import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ -/* Implicits */ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. TestData - var origZone: TimeZone = _ - override protected def beforeAll() { - origZone = TimeZone.getDefault - TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + import org.apache.spark.sql.test.TestSQLContext.implicits._ + val sqlCtx = TestSQLContext + + test("self join with aliases") { + Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") + + checkAnswer( + sql( + """ + |SELECT x.str, COUNT(*) + |FROM df x JOIN df y ON x.str = y.str + |GROUP BY x.str + """.stripMargin), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } - override protected def afterAll() { - TimeZone.setDefault(origZone) + test("self join with alias in agg") { + Seq(1,2,3) + .map(i => (i, i.toString)) + .toDF("int", "str") + .groupBy("str") + .agg($"str", count("str").as("strCount")) + .registerTempTable("df") + + checkAnswer( + sql( + """ + |SELECT x.str, SUM(x.strCount) + |FROM df x JOIN df y ON x.str = y.str + |GROUP BY x.str + """.stripMargin), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { @@ -143,26 +165,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3173 Timestamp support in the parser") { checkAnswer(sql( - "SELECT time FROM timestamps WHERE time=CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE time=CAST('1969-12-31 16:00:00.001' AS TIMESTAMP)"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE time='1970-01-01 00:00:00.001'"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.001'"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE '1970-01-01 00:00:00.001'=time"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE '1969-12-31 16:00:00.001'=time"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - """SELECT time FROM timestamps WHERE time<'1970-01-01 00:00:00.003' - AND time>'1970-01-01 00:00:00.001'"""), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))) + """SELECT time FROM timestamps WHERE time<'1969-12-31 16:00:00.003' + AND time>'1969-12-31 16:00:00.001'"""), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001','1970-01-01 00:00:00.002')"), - Seq(Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + "SELECT time FROM timestamps WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002')"), + Seq(Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001")), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002")))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time='123'"), @@ -296,6 +318,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { mapData.collect().take(1).map(Row.fromTuple).toSeq) } + test("date row") { + checkAnswer(sql( + """select cast("2015-01-28" as date) from testData limit 1"""), + Row(java.sql.Date.valueOf("2015-01-28")) + ) + } + test("from follow multiple brackets") { checkAnswer(sql( "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"), @@ -592,7 +621,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) // Column type mismatches where a coercion is not possible, in this case between integer // and array types, trigger a TreeNodeException. - intercept[TreeNodeException[_]] { + intercept[AnalysisException] { sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect() } } @@ -672,7 +701,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = applySchema(rowRDD1, schema1) + val df1 = sqlCtx.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -702,7 +731,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = applySchema(rowRDD2, schema2) + val df2 = sqlCtx.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -727,7 +756,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = applySchema(rowRDD3, schema2) + val df3 = sqlCtx.createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -772,7 +801,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = applySchema(person.rdd, schemaWithMeta) + val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -787,7 +816,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3371 Renaming a function expression with group by gives error") { - udf.register("len", (s: String) => s.length) + TestSQLContext.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -808,10 +837,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("throw errors for non-aggregate attributes with aggregation") { def checkAggregation(query: String, isInvalidQuery: Boolean = true) { if (isInvalidQuery) { - val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) - assert( - e.getMessage.startsWith("Expression not in GROUP BY"), - "Non-aggregate attribute(s) not detected\n") + val e = intercept[AnalysisException](sql(query).queryExecution.analyzed) + assert(e.getMessage contains "group by") } else { // Should not throw sql(query).queryExecution.analyzed @@ -1038,10 +1065,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) - rdd1.registerTempTable("nulldata1") + rdd1.toDF().registerTempTable("nulldata1") val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) - rdd2.registerTempTable("nulldata2") + rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), (1 to 2).map(i => Row(i))) @@ -1050,7 +1077,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) - rdd.registerTempTable("distinctData") + rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } + + test("SPARK-6145: ORDER BY test for nested fields") { + jsonRDD(sparkContext.makeRDD( + """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)).registerTempTable("nestedOrder") + + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1)) + checkAnswer(sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1)) + } + + test("SPARK-6145: special cases") { + jsonRDD(sparkContext.makeRDD( + """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index a015884bae282..23df6e7eac043 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -75,21 +75,25 @@ case class ComplexReflectData( dataField: Data) class ScalaReflectionRelationSuite extends FunSuite { + + import org.apache.spark.sql.test.TestSQLContext.implicits._ + test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectData") + rdd.toDF().registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))) + new java.math.BigDecimal(1), new Date(70, 0, 1), // This is 1970-01-01 + new Timestamp(12345), Seq(1,2,3))) } test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectNullData") + rdd.toDF().registerTempTable("reflectNullData") assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -97,7 +101,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectOptionalData") + rdd.toDF().registerTempTable("reflectOptionalData") assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -105,7 +109,7 @@ class ScalaReflectionRelationSuite extends FunSuite { // Equality is broken for Arrays, so we test that separately. test("query binary data") { val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.registerTempTable("reflectBinary") + rdd.toDF().registerTempTable("reflectBinary") val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) @@ -124,7 +128,7 @@ class ScalaReflectionRelationSuite extends FunSuite { Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), Nested(None, "abc"))) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectComplexData") + rdd.toDF().registerTempTable("reflectComplexData") assert(sql("SELECT * FROM reflectComplexData").collect().head === new GenericRow(Array[Any]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala new file mode 100644 index 0000000000000..6f6d3c9c243d4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.test.TestSQLContext + +class SerializationSuite extends FunSuite { + + test("[SPARK-5235] SQLContext should be serializable") { + val sqlContext = new SQLContext(TestSQLContext.sparkContext) + new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index dd781169ca57f..637f59b2e68ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -20,21 +20,20 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ -/* Implicits */ -import org.apache.spark.sql.test.TestSQLContext._ case class TestData(key: Int, value: String) object TestData { val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDataFrame + (1 to 100).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toDataFrame + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() negativeData.registerTempTable("negativeData") case class LargeAndSmallInts(a: Int, b: Int) @@ -45,7 +44,7 @@ object TestData { LargeAndSmallInts(2147483645, 1) :: LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toDataFrame + LargeAndSmallInts(3, 2) :: Nil).toDF() largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) @@ -56,7 +55,7 @@ object TestData { TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toDataFrame + TestData2(3, 2) :: Nil, 2).toDF() testData2.registerTempTable("testData2") case class DecimalData(a: BigDecimal, b: BigDecimal) @@ -68,7 +67,7 @@ object TestData { DecimalData(2, 1) :: DecimalData(2, 2) :: DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toDataFrame + DecimalData(3, 2) :: Nil).toDF() decimalData.registerTempTable("decimalData") case class BinaryData(a: Array[Byte], b: Int) @@ -78,14 +77,14 @@ object TestData { BinaryData("22".getBytes(), 5) :: BinaryData("122".getBytes(), 3) :: BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toDataFrame + BinaryData("123".getBytes(), 4) :: Nil).toDF() binaryData.registerTempTable("binaryData") case class TestData3(a: Int, b: Option[Int]) val testData3 = TestSQLContext.sparkContext.parallelize( TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toDataFrame + TestData3(2, Some(2)) :: Nil).toDF() testData3.registerTempTable("testData3") val emptyTableData = logical.LocalRelation($"a".int, $"b".int) @@ -98,7 +97,7 @@ object TestData { UpperCaseData(3, "C") :: UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toDataFrame + UpperCaseData(6, "F") :: Nil).toDF() upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) @@ -107,7 +106,7 @@ object TestData { LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toDataFrame + LowerCaseData(4, "d") :: Nil).toDF() lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) @@ -115,7 +114,7 @@ object TestData { TestSQLContext.sparkContext.parallelize( ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) - arrayData.registerTempTable("arrayData") + arrayData.toDF().registerTempTable("arrayData") case class MapData(data: scala.collection.Map[Int, String]) val mapData = @@ -125,18 +124,18 @@ object TestData { MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: MapData(Map(1 -> "a4", 2 -> "b4")) :: MapData(Map(1 -> "a5")) :: Nil) - mapData.registerTempTable("mapData") + mapData.toDF().registerTempTable("mapData") case class StringData(s: String) val repeatedData = TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.registerTempTable("repeatedData") + repeatedData.toDF().registerTempTable("repeatedData") val nullableRepeatedData = TestSQLContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) - nullableRepeatedData.registerTempTable("nullableRepeatedData") + nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData") case class NullInts(a: Integer) val nullInts = @@ -145,7 +144,7 @@ object TestData { NullInts(2) :: NullInts(3) :: NullInts(null) :: Nil - ) + ).toDF() nullInts.registerTempTable("nullInts") val allNulls = @@ -153,7 +152,7 @@ object TestData { NullInts(null) :: NullInts(null) :: NullInts(null) :: - NullInts(null) :: Nil) + NullInts(null) :: Nil).toDF() allNulls.registerTempTable("allNulls") case class NullStrings(n: Int, s: String) @@ -161,11 +160,15 @@ object TestData { TestSQLContext.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil).toDataFrame + NullStrings(3, null) :: Nil).toDF() nullStrings.registerTempTable("nullStrings") case class TableName(tableName: String) - TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerTempTable("tableName") + TestSQLContext + .sparkContext + .parallelize(TableName("test") :: Nil) + .toDF() + .registerTempTable("tableName") val unparsedStrings = TestSQLContext.sparkContext.parallelize( @@ -178,22 +181,22 @@ object TestData { val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => TimestampField(new Timestamp(i)) }) - timestamps.registerTempTable("timestamps") + timestamps.toDF().registerTempTable("timestamps") case class IntField(i: Int) // An RDD with 4 elements and 8 partitions val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - withEmptyParts.registerTempTable("withEmptyParts") + withEmptyParts.toDF().registerTempTable("withEmptyParts") case class Person(id: Int, name: String, age: Int) case class Salary(personId: Int, salary: Double) val person = TestSQLContext.sparkContext.parallelize( Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil) + Person(1, "jim", 20) :: Nil).toDF() person.registerTempTable("person") val salary = TestSQLContext.sparkContext.parallelize( Salary(0, 2000.0) :: - Salary(1, 1000.0) :: Nil) + Salary(1, 1000.0) :: Nil).toDF() salary.registerTempTable("salary") case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean) @@ -201,6 +204,6 @@ object TestData { TestSQLContext.sparkContext.parallelize( ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) - :: Nil).toDataFrame + :: Nil).toDF() complexData.registerTempTable("complexData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 95923f9aad931..d615542ab50a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql -import org.apache.spark.sql.Dsl.StringToColumn import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +import TestSQLContext.implicits._ case class FunctionResult(f1: String, f2: String) @@ -50,4 +50,10 @@ class UDFSuite extends QueryTest { .select($"ret.f1").head().getString(0) assert(result === "test") } + + test("udf that is transformed") { + udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + // 1 + 1 is constant folded causing a transformation. + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 0696a2335e63f..23f424c0bfc7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql +import java.io.File + import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ @@ -58,13 +62,15 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } override def userClass = classOf[MyDenseVector] + + private[spark] override def asNullable: MyDenseVectorUDT = this } class UserDefinedTypeSuite extends QueryTest { val points = Seq( MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) + val pointsRDD = sparkContext.parallelize(points).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { @@ -83,10 +89,32 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + TestSQLContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } + + + test("UDTs with Parquet") { + val tempDir = File.createTempFile("parquet", "test") + tempDir.delete() + pointsRDD.saveAsParquetFile(tempDir.getCanonicalPath) + } + + test("Repartition UDTs with Parquet") { + val tempDir = File.createTempFile("parquet", "test") + tempDir.delete() + pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath) + } + + // Tests to make sure that all operators correctly convert types on the way out. + test("Local UDTs") { + val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec") + df.collect()(0).getAs[MyDenseVector](1) + df.take(1)(0).getAs[MyDenseVector](1) + df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) + df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index be2b34de077c9..581fccf8ee613 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -30,7 +30,7 @@ class ColumnStatsSuite extends FunSuite { testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) - testColumnStats(classOf[DateColumnStats], DATE, Row(null, null, 0)) + testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) def testColumnStats[T <: NativeType, U <: ColumnStats]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 87e608a8853dc..9ce845912f1c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import org.scalatest.FunSuite @@ -34,7 +34,7 @@ class ColumnTypeSuite extends FunSuite with Logging { test("defaultSize") { val checks = Map( INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1, - STRING -> 8, DATE -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) + STRING -> 8, DATE -> 4, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -64,7 +64,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(BOOLEAN, true, 1) checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) - checkActualSize(DATE, new Date(0L), 8) + checkActualSize(DATE, 0, 4) checkActualSize(TIMESTAMP, new Timestamp(0L), 12) val binary = Array.fill[Byte](4)(0: Byte) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index f941465fa3e35..60ed28cc97bf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.columnar +import java.sql.Timestamp + import scala.collection.immutable.HashSet import scala.util.Random -import java.sql.{Date, Timestamp} - import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types.{DataType, NativeType} @@ -50,7 +50,7 @@ object ColumnarTestUtils { case STRING => Random.nextString(Random.nextInt(32)) case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) - case DATE => new Date(Random.nextLong()) + case DATE => Random.nextInt() case TIMESTAMP => val timestamp = new Timestamp(Random.nextLong()) timestamp.setNanos(Random.nextInt(999999999)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 3d33484ab0eb9..38b0f666ab90b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.columnar -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY @@ -37,7 +38,8 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).registerTempTable("sizeTst") + sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + .toDF().registerTempTable("sizeTst") cacheTable("sizeTst") assert( table("sizeTst").queryExecution.logical.statistics.sizeInBytes > diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index fe9a69edbb920..e57bb06e7263b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { val originalColumnBatchSize = conf.columnBatchSize @@ -33,7 +34,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be val pruningData = sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) - }, 5) + }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index df108a9d262bb..523be56df65ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite import org.apache.spark.sql.{SQLConf, execution} -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ @@ -71,7 +72,7 @@ class PlannerSuite extends FunSuite { val schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) - applySchema(rowRDD, schema).registerTempTable("testLimit") + createDataFrame(rowRDD, schema).registerTempTable("testLimit") val planned = sql( """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d25c1390db15c..592ed4b23b7d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -18,40 +18,52 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal +import java.sql.DriverManager +import java.util.{Calendar, GregorianCalendar, Properties} + import org.apache.spark.sql.test._ +import org.h2.jdbc.JdbcSQLException import org.scalatest.{FunSuite, BeforeAndAfter} -import java.sql.DriverManager import TestSQLContext._ +import TestSQLContext.implicits._ class JDBCSuite extends FunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" + val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) before { Class.forName("org.h2.Driver") - conn = DriverManager.getConnection(url) + // Extra properties that will be specified for our database. We need these to test + // usage of parameters from OPTIONS clause in queries. + val properties = new Properties() + properties.setProperty("user", "testUser") + properties.setProperty("password", "testPass") + properties.setProperty("rowId", "false") + + conn = DriverManager.getConnection(url, properties) conn.prepareStatement("create schema test").executeUpdate() conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate() - conn.prepareStatement("insert into test.people values ('joe', 3)").executeUpdate() + conn.prepareStatement("insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() conn.commit() sql( s""" |CREATE TEMPORARY TABLE foobar |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.PEOPLE') + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) sql( s""" |CREATE TEMPORARY TABLE parts |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', - |partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, " @@ -65,12 +77,12 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE inttypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.INTTYPES') + |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.strtypes (a BINARY(20), b VARCHAR(20), " + "c VARCHAR_IGNORECASE(20), d CHAR(20), e BLOB, f CLOB)").executeUpdate() - var stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)") + val stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)") stmt.setBytes(1, testBytes) stmt.setString(2, "Sensitive") stmt.setString(3, "Insensitive") @@ -82,7 +94,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE strtypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.STRTYPES') + |OPTIONS (url '$url', dbtable 'TEST.STRTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)" @@ -94,7 +106,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE timetypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES') + |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) @@ -109,7 +121,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE flttypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES') + |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. @@ -127,13 +139,20 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size == 0) assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size == 2) assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size == 1) + assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size == 1) + assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size == 2) + assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size == 2) + } + + test("SELECT * WHERE (quoted strings)") { + assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size == 1) } test("SELECT first field") { val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) assert(names.size == 3) assert(names(0).equals("fred")) - assert(names(1).equals("joe")) + assert(names(1).equals("joe 'foo' \"bar\"")) assert(names(2).equals("mary")) } @@ -164,17 +183,17 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("Basic API") { - assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE").collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect.size == 3) } test("Partitioning via JDBCPartitioningInfo API") { - val parts = JDBCPartitioningInfo("THEID", 0, 4, 3) - assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE", parts).collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3) + .collect.size == 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE", parts).collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect.size == 3) } test("H2 integral types") { @@ -209,18 +228,22 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { test("H2 time types") { val rows = sql("SELECT * FROM timetypes").collect() - assert(rows(0).getAs[java.sql.Timestamp](0).getHours == 12) - assert(rows(0).getAs[java.sql.Timestamp](0).getMinutes == 34) - assert(rows(0).getAs[java.sql.Timestamp](0).getSeconds == 56) - assert(rows(0).getAs[java.sql.Date](1).getYear == 96) - assert(rows(0).getAs[java.sql.Date](1).getMonth == 0) - assert(rows(0).getAs[java.sql.Date](1).getDate == 1) - assert(rows(0).getAs[java.sql.Timestamp](2).getYear == 102) - assert(rows(0).getAs[java.sql.Timestamp](2).getMonth == 1) - assert(rows(0).getAs[java.sql.Timestamp](2).getDate == 20) - assert(rows(0).getAs[java.sql.Timestamp](2).getHours == 11) - assert(rows(0).getAs[java.sql.Timestamp](2).getMinutes == 22) - assert(rows(0).getAs[java.sql.Timestamp](2).getSeconds == 33) + val cal = new GregorianCalendar(java.util.Locale.ROOT) + cal.setTime(rows(0).getAs[java.sql.Timestamp](0)) + assert(cal.get(Calendar.HOUR_OF_DAY) == 12) + assert(cal.get(Calendar.MINUTE) == 34) + assert(cal.get(Calendar.SECOND) == 56) + cal.setTime(rows(0).getAs[java.sql.Timestamp](1)) + assert(cal.get(Calendar.YEAR) == 1996) + assert(cal.get(Calendar.MONTH) == 0) + assert(cal.get(Calendar.DAY_OF_MONTH) == 1) + cal.setTime(rows(0).getAs[java.sql.Timestamp](2)) + assert(cal.get(Calendar.YEAR) == 2002) + assert(cal.get(Calendar.MONTH) == 1) + assert(cal.get(Calendar.DAY_OF_MONTH) == 20) + assert(cal.get(Calendar.HOUR) == 11) + assert(cal.get(Calendar.MINUTE) == 22) + assert(cal.get(Calendar.SECOND) == 33) assert(rows(0).getAs[java.sql.Timestamp](2).getNanos == 543543543) } @@ -232,17 +255,31 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { .equals(new BigDecimal("123456789012345.54321543215432100000"))) } - test("SQL query as table name") { sql( s""" |CREATE TEMPORARY TABLE hack |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)') + |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)', + | user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) val rows = sql("SELECT * FROM hack").collect() assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==. // For some reason, H2 computes this square incorrectly... assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) } + + test("Pass extra properties via OPTIONS") { + // We set rowId to false during setup, which means that _ROWID_ column should be absent from + // all tables. If rowId is true (default), the query below doesn't throw an exception. + intercept[JdbcSQLException] { + sql( + s""" + |CREATE TEMPORARY TABLE abc + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable '(SELECT _ROWID_ FROM test.people)', + | user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e581ac9b50c2b..ee5c7620d1a22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.jdbc -import java.math.BigDecimal +import java.sql.DriverManager + +import org.scalatest.{BeforeAndAfter, FunSuite} + import org.apache.spark.sql.Row -import org.apache.spark.sql.types._ import org.apache.spark.sql.test._ -import org.scalatest.{FunSuite, BeforeAndAfter} -import java.sql.DriverManager -import TestSQLContext._ +import org.apache.spark.sql.types._ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" @@ -54,53 +54,53 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - srdd.createJDBCTable(url, "TEST.BASICCREATETEST", false) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").count) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").collect()(0).length) + df.createJDBCTable(url, "TEST.BASICCREATETEST", false) + assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").count) + assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").collect()(0).length) } test("CREATE with overwrite") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2) + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) - srdd.createJDBCTable(url, "TEST.DROPTEST", false) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count) - assert(3 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").collect()(0).length) + df.createJDBCTable(url, "TEST.DROPTEST", false) + assert(2 == TestSQLContext.jdbc(url, "TEST.DROPTEST").count) + assert(3 == TestSQLContext.jdbc(url, "TEST.DROPTEST").collect()(0).length) - srdd2.createJDBCTable(url, "TEST.DROPTEST", true) - assert(1 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").collect()(0).length) + df2.createJDBCTable(url, "TEST.DROPTEST", true) + assert(1 == TestSQLContext.jdbc(url, "TEST.DROPTEST").count) + assert(2 == TestSQLContext.jdbc(url, "TEST.DROPTEST").collect()(0).length) } test("CREATE then INSERT to append") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2) + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) - srdd.createJDBCTable(url, "TEST.APPENDTEST", false) - srdd2.insertIntoJDBC(url, "TEST.APPENDTEST", false) - assert(3 == TestSQLContext.jdbcRDD(url, "TEST.APPENDTEST").count) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.APPENDTEST").collect()(0).length) + df.createJDBCTable(url, "TEST.APPENDTEST", false) + df2.insertIntoJDBC(url, "TEST.APPENDTEST", false) + assert(3 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").count) + assert(2 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").collect()(0).length) } test("CREATE then INSERT to truncate") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2) + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) - srdd.createJDBCTable(url, "TEST.TRUNCATETEST", false) - srdd2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true) - assert(1 == TestSQLContext.jdbcRDD(url, "TEST.TRUNCATETEST").count) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.TRUNCATETEST").collect()(0).length) + df.createJDBCTable(url, "TEST.TRUNCATETEST", false) + df2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true) + assert(1 == TestSQLContext.jdbc(url, "TEST.TRUNCATETEST").count) + assert(2 == TestSQLContext.jdbc(url, "TEST.TRUNCATETEST").collect()(0).length) } test("Incompatible INSERT to append") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3) + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) - srdd.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false) + df.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false) intercept[org.apache.spark.SparkException] { - srdd2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true) + df2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala deleted file mode 100644 index 89920f2650c3a..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.math.BigDecimal -import java.sql.{Date, DriverManager, Timestamp} -import com.spotify.docker.client.{DefaultDockerClient, DockerClient} -import com.spotify.docker.client.messages.ContainerConfig -import org.scalatest.{FunSuite, BeforeAndAfterAll, Ignore} - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ -import org.apache.spark.sql._ -import org.apache.spark.sql.test._ -import TestSQLContext._ - -import org.apache.spark.sql.jdbc._ - -class MySQLDatabase { - val docker: DockerClient = DockerClientFactory.get() - val containerId = { - println("Pulling mysql") - docker.pull("mysql") - println("Configuring container") - val config = (ContainerConfig.builder().image("mysql") - .env("MYSQL_ROOT_PASSWORD=rootpass") - .build()) - println("Creating container") - val id = docker.createContainer(config).id - println("Starting container " + id) - docker.startContainer(id) - id - } - val ip = docker.inspectContainer(containerId).networkSettings.ipAddress - - def close() { - try { - println("Killing container " + containerId) - docker.killContainer(containerId) - println("Removing container " + containerId) - docker.removeContainer(containerId) - println("Closing docker client") - DockerClientFactory.close(docker) - } catch { - case e: Exception => { - println(e) - println("You may need to clean this up manually.") - throw e - } - } - } -} - -@Ignore class MySQLIntegration extends FunSuite with BeforeAndAfterAll { - var ip: String = null - - def url(ip: String): String = url(ip, "mysql") - def url(ip: String, db: String): String = s"jdbc:mysql://$ip:3306/$db?user=root&password=rootpass" - - def waitForDatabase(ip: String, maxMillis: Long) { - println("Waiting for database to start up.") - val before = System.currentTimeMillis() - var lastException: java.sql.SQLException = null - while (true) { - if (System.currentTimeMillis() > before + maxMillis) { - throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException) - } - try { - val conn = java.sql.DriverManager.getConnection(url(ip)) - conn.close() - println("Database is up.") - return; - } catch { - case e: java.sql.SQLException => { - lastException = e - java.lang.Thread.sleep(250) - } - } - } - } - - def setupDatabase(ip: String) { - val conn = java.sql.DriverManager.getConnection(url(ip)) - try { - conn.prepareStatement("CREATE DATABASE foo").executeUpdate() - conn.prepareStatement("CREATE TABLE foo.tbl (x INTEGER, y TEXT(8))").executeUpdate() - conn.prepareStatement("INSERT INTO foo.tbl VALUES (42,'fred')").executeUpdate() - conn.prepareStatement("INSERT INTO foo.tbl VALUES (17,'dave')").executeUpdate() - - conn.prepareStatement("CREATE TABLE foo.numbers (onebit BIT(1), tenbits BIT(10), " - + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " - + "dbl DOUBLE)").executeUpdate() - conn.prepareStatement("INSERT INTO foo.numbers VALUES (b'0', b'1000100101', " - + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " - + "42.75, 1.0000000000000002)").executeUpdate() - - conn.prepareStatement("CREATE TABLE foo.dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " - + "yr YEAR)").executeUpdate() - conn.prepareStatement("INSERT INTO foo.dates VALUES ('1991-11-09', '13:31:24', " - + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() - - // TODO: Test locale conversion for strings. - conn.prepareStatement("CREATE TABLE foo.strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, " - + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)" - ).executeUpdate() - conn.prepareStatement("INSERT INTO foo.strings VALUES ('the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate() - } finally { - conn.close() - } - } - - var db: MySQLDatabase = null - - override def beforeAll() { - // If you load the MySQL driver here, DriverManager will deadlock. The - // MySQL driver gets loaded when its jar gets loaded, unlike the Postgres - // and H2 drivers. - //Class.forName("com.mysql.jdbc.Driver") - - db = new MySQLDatabase() - waitForDatabase(db.ip, 60000) - setupDatabase(db.ip) - ip = db.ip - } - - override def afterAll() { - db.close() - } - - test("Basic test") { - val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "tbl") - val rows = rdd.collect - assert(rows.length == 2) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 2) - assert(types(0).equals("class java.lang.Integer")) - assert(types(1).equals("class java.lang.String")) - } - - test("Numeric types") { - val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "numbers") - val rows = rdd.collect - assert(rows.length == 1) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 9) - println(types(1)) - assert(types(0).equals("class java.lang.Boolean")) - assert(types(1).equals("class java.lang.Long")) - assert(types(2).equals("class java.lang.Integer")) - assert(types(3).equals("class java.lang.Integer")) - assert(types(4).equals("class java.lang.Integer")) - assert(types(5).equals("class java.lang.Long")) - assert(types(6).equals("class java.math.BigDecimal")) - assert(types(7).equals("class java.lang.Double")) - assert(types(8).equals("class java.lang.Double")) - assert(rows(0).getBoolean(0) == false) - assert(rows(0).getLong(1) == 0x225) - assert(rows(0).getInt(2) == 17) - assert(rows(0).getInt(3) == 77777) - assert(rows(0).getInt(4) == 123456789) - assert(rows(0).getLong(5) == 123456789012345L) - val bd = new BigDecimal("123456789012345.12345678901234500000") - assert(rows(0).getAs[BigDecimal](6).equals(bd)) - assert(rows(0).getDouble(7) == 42.75) - assert(rows(0).getDouble(8) == 1.0000000000000002) - } - - test("Date types") { - val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "dates") - val rows = rdd.collect - assert(rows.length == 1) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 5) - assert(types(0).equals("class java.sql.Date")) - assert(types(1).equals("class java.sql.Timestamp")) - assert(types(2).equals("class java.sql.Timestamp")) - assert(types(3).equals("class java.sql.Timestamp")) - assert(types(4).equals("class java.sql.Date")) - assert(rows(0).getAs[Date](0).equals(new Date(91, 10, 9))) - assert(rows(0).getAs[Timestamp](1).equals(new Timestamp(70, 0, 1, 13, 31, 24, 0))) - assert(rows(0).getAs[Timestamp](2).equals(new Timestamp(96, 0, 1, 1, 23, 45, 0))) - assert(rows(0).getAs[Timestamp](3).equals(new Timestamp(109, 1, 13, 23, 31, 30, 0))) - assert(rows(0).getAs[Date](4).equals(new Date(101, 0, 1))) - } - - test("String types") { - val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "strings") - val rows = rdd.collect - assert(rows.length == 1) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 9) - assert(types(0).equals("class java.lang.String")) - assert(types(1).equals("class java.lang.String")) - assert(types(2).equals("class java.lang.String")) - assert(types(3).equals("class java.lang.String")) - assert(types(4).equals("class java.lang.String")) - assert(types(5).equals("class java.lang.String")) - assert(types(6).equals("class [B")) - assert(types(7).equals("class [B")) - assert(types(8).equals("class [B")) - assert(rows(0).getString(0).equals("the")) - assert(rows(0).getString(1).equals("quick")) - assert(rows(0).getString(2).equals("brown")) - assert(rows(0).getString(3).equals("fox")) - assert(rows(0).getString(4).equals("jumps")) - assert(rows(0).getString(5).equals("over")) - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0))) - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121))) - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103))) - } - - test("Basic write test") { - val rdd1 = TestSQLContext.jdbcRDD(url(ip, "foo"), "numbers") - val rdd2 = TestSQLContext.jdbcRDD(url(ip, "foo"), "dates") - val rdd3 = TestSQLContext.jdbcRDD(url(ip, "foo"), "strings") - rdd1.createJDBCTable(url(ip, "foo"), "numberscopy", false) - rdd2.createJDBCTable(url(ip, "foo"), "datescopy", false) - rdd3.createJDBCTable(url(ip, "foo"), "stringscopy", false) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala deleted file mode 100644 index c174d7adb7204..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.math.BigDecimal -import org.apache.spark.sql.test._ -import org.scalatest.{FunSuite, BeforeAndAfterAll, Ignore} -import java.sql.DriverManager -import TestSQLContext._ -import com.spotify.docker.client.{DefaultDockerClient, DockerClient} -import com.spotify.docker.client.messages.ContainerConfig - -class PostgresDatabase { - val docker: DockerClient = DockerClientFactory.get() - val containerId = { - println("Pulling postgres") - docker.pull("postgres") - println("Configuring container") - val config = (ContainerConfig.builder().image("postgres") - .env("POSTGRES_PASSWORD=rootpass") - .build()) - println("Creating container") - val id = docker.createContainer(config).id - println("Starting container " + id) - docker.startContainer(id) - id - } - val ip = docker.inspectContainer(containerId).networkSettings.ipAddress - - def close() { - try { - println("Killing container " + containerId) - docker.killContainer(containerId) - println("Removing container " + containerId) - docker.removeContainer(containerId) - println("Closing docker client") - DockerClientFactory.close(docker) - } catch { - case e: Exception => { - println(e) - println("You may need to clean this up manually.") - throw e - } - } - } -} - -@Ignore class PostgresIntegration extends FunSuite with BeforeAndAfterAll { - lazy val db = new PostgresDatabase() - - def url(ip: String) = s"jdbc:postgresql://$ip:5432/postgres?user=postgres&password=rootpass" - - def waitForDatabase(ip: String, maxMillis: Long) { - val before = System.currentTimeMillis() - var lastException: java.sql.SQLException = null - while (true) { - if (System.currentTimeMillis() > before + maxMillis) { - throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", - lastException) - } - try { - val conn = java.sql.DriverManager.getConnection(url(ip)) - conn.close() - println("Database is up.") - return; - } catch { - case e: java.sql.SQLException => { - lastException = e - java.lang.Thread.sleep(250) - } - } - } - } - - def setupDatabase(ip: String) { - val conn = DriverManager.getConnection(url(ip)) - try { - conn.prepareStatement("CREATE DATABASE foo").executeUpdate() - conn.setCatalog("foo") - conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, " - + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate() - conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " - + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate() - } finally { - conn.close() - } - } - - override def beforeAll() { - println("Waiting for database to start up.") - waitForDatabase(db.ip, 60000) - println("Setting up database.") - setupDatabase(db.ip) - } - - override def afterAll() { - db.close() - } - - test("Type mapping for various types") { - val rdd = TestSQLContext.jdbcRDD(url(db.ip), "public.bar") - val rows = rdd.collect - assert(rows.length == 1) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 10) - assert(types(0).equals("class java.lang.String")) - assert(types(1).equals("class java.lang.Integer")) - assert(types(2).equals("class java.lang.Double")) - assert(types(3).equals("class java.lang.Long")) - assert(types(4).equals("class java.lang.Boolean")) - assert(types(5).equals("class [B")) - assert(types(6).equals("class [B")) - assert(types(7).equals("class java.lang.Boolean")) - assert(types(8).equals("class java.lang.String")) - assert(types(9).equals("class java.lang.String")) - assert(rows(0).getString(0).equals("hello")) - assert(rows(0).getInt(1) == 42) - assert(rows(0).getDouble(2) == 1.25) - assert(rows(0).getLong(3) == 123456789012345L) - assert(rows(0).getBoolean(4) == false) - // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), Array[Byte](49,48,48,48,49,48,48,49,48,49))) - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) - assert(rows(0).getBoolean(7) == true) - assert(rows(0).getString(8) == "172.16.0.42") - assert(rows(0).getString(9) == "192.168.0.0/16") - } - - test("Basic write test") { - val rdd = TestSQLContext.jdbcRDD(url(db.ip), "public.bar") - rdd.createJDBCTable(url(db.ip), "public.barcopy", false) - // Test only that it doesn't bomb out. - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index cb615388da0c7..320b80d80e997 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -19,12 +19,16 @@ package org.apache.spark.sql.json import java.sql.{Date, Timestamp} +import org.scalactic.Tolerance._ + import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} +import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} @@ -67,14 +71,15 @@ class JsonSuite extends QueryTest { checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType)) val strDate = "2014-10-15" - checkTypePromotion(Date.valueOf(strDate), enforceCorrectType(strDate, DateType)) + checkTypePromotion( + DateUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType)) - checkTypePromotion(new Date(3601000), enforceCorrectType(ISO8601Time1, DateType)) + checkTypePromotion(DateUtils.millisToDays(3601000), enforceCorrectType(ISO8601Time1, DateType)) val ISO8601Time2 = "1970-01-01T02:00:01-01:00" checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType)) - checkTypePromotion(new Date(10801000), enforceCorrectType(ISO8601Time2, DateType)) + checkTypePromotion(DateUtils.millisToDays(10801000), enforceCorrectType(ISO8601Time2, DateType)) } test("Get compatible type") { @@ -220,7 +225,7 @@ class JsonSuite extends QueryTest { StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: - StructField("integer", IntegerType, true) :: + StructField("integer", LongType, true) :: StructField("long", LongType, true) :: StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) @@ -245,26 +250,26 @@ class JsonSuite extends QueryTest { val jsonDF = jsonRDD(complexFieldAndType1) val expectedSchema = StructType( - StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) :: - StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) :: - StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) :: - StructField("arrayOfDouble", ArrayType(DoubleType, false), true) :: - StructField("arrayOfInteger", ArrayType(IntegerType, false), true) :: - StructField("arrayOfLong", ArrayType(LongType, false), true) :: + StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: + StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, true), true) :: + StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) :: + StructField("arrayOfDouble", ArrayType(DoubleType, true), true) :: + StructField("arrayOfInteger", ArrayType(LongType, true), true) :: + StructField("arrayOfLong", ArrayType(LongType, true), true) :: StructField("arrayOfNull", ArrayType(StringType, true), true) :: - StructField("arrayOfString", ArrayType(StringType, false), true) :: + StructField("arrayOfString", ArrayType(StringType, true), true) :: StructField("arrayOfStruct", ArrayType( StructType( StructField("field1", BooleanType, true) :: StructField("field2", StringType, true) :: - StructField("field3", StringType, true) :: Nil), false), true) :: + StructField("field3", StringType, true) :: Nil), true), true) :: StructField("struct", StructType( StructField("field1", BooleanType, true) :: StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( - StructField("field1", ArrayType(IntegerType, false), true) :: - StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil) + StructField("field1", ArrayType(LongType, true), true) :: + StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil) assert(expectedSchema === jsonDF.schema) @@ -340,21 +345,19 @@ class JsonSuite extends QueryTest { ) } - ignore("Complex field and type inferring (Ignored)") { + test("GetField operation on complex data type") { val jsonDF = jsonRDD(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") - // Right now, "field1" and "field2" are treated as aliases. We should fix it. checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), Row(true, "str1") ) - // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2. // Getting all values of a specific field from an array of structs. checkAnswer( sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), - Row(Seq(true, false), Seq("str1", null)) + Row(Seq(true, false, null), Seq("str1", null, null)) ) } @@ -486,7 +489,7 @@ class JsonSuite extends QueryTest { val jsonDF = jsonRDD(complexFieldValueTypeConflict) val expectedSchema = StructType( - StructField("array", ArrayType(IntegerType, false), true) :: + StructField("array", ArrayType(LongType, true), true) :: StructField("num_struct", StringType, true) :: StructField("str_array", StringType, true) :: StructField("struct", StructType( @@ -512,8 +515,8 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: StructField("array2", ArrayType(StructType( - StructField("field", LongType, true) :: Nil), false), true) :: - StructField("array3", ArrayType(StringType, false), true) :: Nil) + StructField("field", LongType, true) :: Nil), true), true) :: + StructField("array3", ArrayType(StringType, true), true) :: Nil) assert(expectedSchema === jsonDF.schema) @@ -540,7 +543,7 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("a", BooleanType, true) :: StructField("b", LongType, true) :: - StructField("c", ArrayType(IntegerType, false), true) :: + StructField("c", ArrayType(LongType, true), true) :: StructField("d", StructType( StructField("field", BooleanType, true) :: Nil), true) :: StructField("e", StringType, true) :: Nil) @@ -550,6 +553,32 @@ class JsonSuite extends QueryTest { jsonDF.registerTempTable("jsonTable") } + test("jsonFile should be based on JSONRelation") { + val file = getTempFilePath("json") + val path = file.toString + sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) + val jsonDF = jsonFile(path, 0.49) + + val analyzed = jsonDF.queryExecution.analyzed + assert( + analyzed.isInstanceOf[LogicalRelation], + "The DataFrame returned by jsonFile should be based on JSONRelation.") + val relation = analyzed.asInstanceOf[LogicalRelation].relation + assert( + relation.isInstanceOf[JSONRelation], + "The DataFrame returned by jsonFile should be based on JSONRelation.") + assert(relation.asInstanceOf[JSONRelation].path === path) + assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001)) + + val schema = StructType(StructField("a", LongType, true) :: Nil) + val logicalRelation = + jsonFile(path, schema).queryExecution.analyzed.asInstanceOf[LogicalRelation] + val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] + assert(relationWithSchema.path === path) + assert(relationWithSchema.schema === schema) + assert(relationWithSchema.samplingRatio > 0.99) + } + test("Loading a JSON dataset from a text file") { val file = getTempFilePath("json") val path = file.toString @@ -560,7 +589,7 @@ class JsonSuite extends QueryTest { StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: - StructField("integer", IntegerType, true) :: + StructField("integer", LongType, true) :: StructField("long", LongType, true) :: StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) @@ -656,6 +685,62 @@ class JsonSuite extends QueryTest { ) } + test("Applying schemas with MapType") { + val schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + val jsonWithSimpleMap = jsonRDD(mapType1, schemaWithSimpleMap) + + jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") + + checkAnswer( + sql("select map from jsonWithSimpleMap"), + Row(Map("a" -> 1)) :: + Row(Map("b" -> 2)) :: + Row(Map("c" -> 3)) :: + Row(Map("c" -> 1, "d" -> 4)) :: + Row(Map("e" -> null)) :: Nil + ) + + checkAnswer( + sql("select map['c'] from jsonWithSimpleMap"), + Row(null) :: + Row(null) :: + Row(3) :: + Row(1) :: + Row(null) :: Nil + ) + + val innerStruct = StructType( + StructField("field1", ArrayType(IntegerType, true), true) :: + StructField("field2", IntegerType, true) :: Nil) + val schemaWithComplexMap = StructType( + StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) + + val jsonWithComplexMap = jsonRDD(mapType2, schemaWithComplexMap) + + jsonWithComplexMap.registerTempTable("jsonWithComplexMap") + + checkAnswer( + sql("select map from jsonWithComplexMap"), + Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) :: + Row(Map("b" -> Row(null, 2))) :: + Row(Map("c" -> Row(Seq(), 4))) :: + Row(Map("c" -> Row(null, 3), "d" -> Row(Seq(null), null))) :: + Row(Map("e" -> null)) :: + Row(Map("f" -> Row(null, null))) :: Nil + ) + + checkAnswer( + sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"), + Row(Seq(1, 2, 3, null), null) :: + Row(null, null) :: + Row(null, 4) :: + Row(null, 3) :: + Row(null, null) :: + Row(null, null) :: Nil + ) + } + test("SPARK-2096 Correctly parse dot notations") { val jsonDF = jsonRDD(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") @@ -778,15 +863,15 @@ class JsonSuite extends QueryTest { val schema = StructType( StructField("field1", - ArrayType(ArrayType(ArrayType(ArrayType(StringType, false), false), true), false), true) :: + ArrayType(ArrayType(ArrayType(ArrayType(StringType, true), true), true), true), true) :: StructField("field2", ArrayType(ArrayType( - StructType(StructField("Test", IntegerType, true) :: Nil), false), true), true) :: + StructType(StructField("Test", LongType, true) :: Nil), true), true), true) :: StructField("field3", ArrayType(ArrayType( - StructType(StructField("Test", StringType, true) :: Nil), true), false), true) :: + StructType(StructField("Test", StringType, true) :: Nil), true), true), true) :: StructField("field4", - ArrayType(ArrayType(ArrayType(IntegerType, false), true), false), true) :: Nil) + ArrayType(ArrayType(ArrayType(LongType, true), true), true), true) :: Nil) assert(schema === jsonDF.schema) @@ -820,12 +905,12 @@ class JsonSuite extends QueryTest { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = applySchema(rowRDD1, schema1) + val df1 = createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") - val df2 = df1.toDataFrame + val df2 = df1.toDF val result = df2.toJSON.collect() - assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") - assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") + assert(result(0) === "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") + assert(result(3) === "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") val schema2 = StructType( StructField("f1", StructType( @@ -841,13 +926,13 @@ class JsonSuite extends QueryTest { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = applySchema(rowRDD2, schema2) + val df3 = createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") - val df4 = df3.toDataFrame + val df4 = df3.toDF val result2 = df4.toJSON.collect() - assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") - assert(result2(3) == "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") + assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") + assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") val jsonDF = jsonRDD(primitiveFieldAndType) val primTable = jsonRDD(jsonDF.toJSON) @@ -922,6 +1007,37 @@ class JsonSuite extends QueryTest { sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"), Row(5, null) ) + } + test("JSONRelation equality test") { + val relation1 = + JSONRelation("path", 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(null) + val logicalRelation1 = LogicalRelation(relation1) + val relation2 = + JSONRelation("path", 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)))( + org.apache.spark.sql.test.TestSQLContext) + val logicalRelation2 = LogicalRelation(relation2) + val relation3 = + JSONRelation("path", 1.0, Some(StructType(StructField("b", StringType, true) :: Nil)))(null) + val logicalRelation3 = LogicalRelation(relation3) + + assert(relation1 === relation2) + assert(logicalRelation1.sameResult(logicalRelation2), + s"$logicalRelation1 and $logicalRelation2 should be considered having the same result.") + + assert(relation1 !== relation3) + assert(!logicalRelation1.sameResult(logicalRelation3), + s"$logicalRelation1 and $logicalRelation3 should be considered not having the same result.") + + assert(relation2 !== relation3) + assert(!logicalRelation2.sameResult(logicalRelation3), + s"$logicalRelation2 and $logicalRelation3 should be considered not having the same result.") } + + test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { + // This is really a test that it doesn't throw an exception + val emptySchema = JsonRDD.inferSchema(empty, 1.0, "") + assert(StructType(Seq()) === emptySchema) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index 3370b3c98b4be..47a97a49daabb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -146,6 +146,23 @@ object TestJsonData { ]] }""" :: Nil) + val mapType1 = + TestSQLContext.sparkContext.parallelize( + """{"map": {"a": 1}}""" :: + """{"map": {"b": 2}}""" :: + """{"map": {"c": 3}}""" :: + """{"map": {"c": 1, "d": 4}}""" :: + """{"map": {"e": null}}""" :: Nil) + + val mapType2 = + TestSQLContext.sparkContext.parallelize( + """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: + """{"map": {"b": {"field2": 2}}}""" :: + """{"map": {"c": {"field1": [], "field2": 4}}}""" :: + """{"map": {"c": {"field2": 3}, "d": {"field1": [null]}}}""" :: + """{"map": {"e": null}}""" :: + """{"map": {"f": {"field1": null}}}""" :: Nil) + val nullsInArrays = TestSQLContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: @@ -168,4 +185,7 @@ object TestJsonData { """{"a":{, b:3}""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) + + val empty = + TestSQLContext.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index ff91a0eb42049..6a2c2a7c4080a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -17,13 +17,16 @@ package org.apache.spark.sql.parquet +import org.scalatest.BeforeAndAfterAll import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} /** @@ -38,11 +41,11 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. */ -class ParquetFilterSuite extends QueryTest with ParquetTest { +class ParquetFilterSuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext private def checkFilterPredicate( - rdd: DataFrame, + df: DataFrame, predicate: Predicate, filterClass: Class[_ <: FilterPredicate], checker: (DataFrame, Seq[Row]) => Unit, @@ -50,13 +53,21 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { - val query = rdd + val query = df .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { - case plan: ParquetTableScan => plan.columnPruningPred - }.flatten.reduceOption(_ && _) + val maybeAnalyzedPredicate = { + val forParquetTableScan = query.queryExecution.executedPlan.collect { + case plan: ParquetTableScan => plan.columnPruningPred + }.flatten.reduceOption(_ && _) + + val forParquetDataSource = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation2)) => filters + }.flatten.reduceOption(_ && _) + + forParquetTableScan.orElse(forParquetDataSource) + } assert(maybeAnalyzedPredicate.isDefined) maybeAnalyzedPredicate.foreach { pred => @@ -74,52 +85,73 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { private def checkFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: DataFrame): Unit = { - checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) + (implicit df: DataFrame): Unit = { + checkFilterPredicate(df, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) } private def checkFilterPredicate[T] (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T) - (implicit rdd: DataFrame): Unit = { - checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) + (implicit df: DataFrame): Unit = { + checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) + } + + private def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit df: DataFrame): Unit = { + def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { + assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { + df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted + } + } + + checkFilterPredicate(df, predicate, filterClass, checkBinaryAnswer _, expected) + } + + private def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) + (implicit df: DataFrame): Unit = { + checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } test("filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - checkFilterPredicate('_1 === true, classOf[Eq [_]], true) + checkFilterPredicate('_1 === true, classOf[Eq[_]], true) checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) } } test("filter pushdown - short") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => - checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq [_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt [_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt [_]], 4) + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => + checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) + checkFilterPredicate( + Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) - + + checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) + checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, - classOf[Operators.And], 3) - checkFilterPredicate(Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, - classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate( + Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) + checkFilterPredicate( + Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, + classOf[Operators.Or], + Seq(Row(1), Row(4))) } } test("filter pushdown - integer") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) @@ -131,53 +163,53 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) @@ -189,46 +221,47 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - double") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - string") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => + withParquetDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate( '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") - checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) + checkFilterPredicate( + '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") + checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") @@ -244,35 +277,18 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } } - def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: DataFrame): Unit = { - def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = { - assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { - rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted - } - } - - checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected) - } - - def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) - (implicit rdd: DataFrame): Unit = { - checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) - } - test("filter pushdown - binary") { implicit class IntToBinary(int: Int) { def b: Array[Byte] = int.toString.getBytes("UTF-8") } - withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => + withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => + checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkBinaryFilterPredicate( '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) - checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq [_]], 1.b) checkBinaryFilterPredicate( '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) @@ -294,3 +310,44 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } } } + +class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } + + test("SPARK-6554: don't push down predicates which reference partition columns") { + import sqlContext.implicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + checkAnswer( + sqlContext.parquetFile(path).filter("part = 1"), + (1 to 3).map(i => Row(i, i.toString, 1))) + } + } + } +} + +class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index d9ab16baf9a66..203bc79f153dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -21,24 +21,25 @@ import scala.collection.JavaConversions._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.scalatest.BeforeAndAfterAll import parquet.example.data.simple.SimpleGroup import parquet.example.data.{Group, GroupWriter} import parquet.hadoop.api.WriteSupport import parquet.hadoop.api.WriteSupport.WriteContext -import parquet.hadoop.metadata.CompressionCodecName -import parquet.hadoop.{ParquetFileWriter, ParquetWriter} +import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} +import parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} import parquet.io.api.RecordConsumer import parquet.schema.{MessageType, MessageTypeParser} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} -import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) @@ -63,14 +64,16 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. */ -class ParquetIOSuite extends QueryTest with ParquetTest { +class ParquetIOSuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext + import sqlContext.implicits.localSeqToDataFrameHolder + /** * Writes `data` to a Parquet file, reads it back and check file contents. */ protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { - withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple))) + withParquetDataFrame(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } test("basic data types (without binary)") { @@ -82,9 +85,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest { test("raw binary") { val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) - withParquetRDD(data) { rdd => + withParquetDataFrame(data) { df => assertResult(data.map(_._1.mkString(",")).sorted) { - rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted + df.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted } } } @@ -98,11 +101,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } test("fixed-length decimals") { + def makeDecimalRDD(decimal: DecimalType): DataFrame = sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) - .select($"_1" cast decimal as "abcd") + .toDF() + // Parquet doesn't allow column names with spaces, have to add an alias here + .select($"_1" cast decimal as "dec") for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { withTempPath { dir => @@ -129,6 +135,21 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } } + test("date type") { + def makeDateRDD(): DataFrame = + sparkContext + .parallelize(0 to 1000) + .map(i => Tuple1(DateUtils.toJavaDate(i))) + .toDF() + .select($"_1") + + withTempPath { dir => + val data = makeDateRDD() + data.saveAsParquetFile(dir.getCanonicalPath) + checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + } + } + test("map") { val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) checkParquetFile(data) @@ -141,9 +162,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest { test("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) - withParquetRDD(data) { rdd => + withParquetDataFrame(data) { df => // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => + checkAnswer(df, data.map { case Tuple1(struct) => Row(Row(struct.productIterator.toSeq: _*)) }) } @@ -151,9 +172,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest { test("nested struct with array of array as field") { val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) - withParquetRDD(data) { rdd => + withParquetDataFrame(data) { df => // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => + checkAnswer(df, data.map { case Tuple1(struct) => Row(Row(struct.productIterator.toSeq: _*)) }) } @@ -161,8 +182,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest { test("nested map with struct as value type") { val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) - withParquetRDD(data) { rdd => - checkAnswer(rdd, data.map { case Tuple1(m) => + withParquetDataFrame(data) { df => + checkAnswer(df, data.map { case Tuple1(m) => Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) }) } @@ -176,8 +197,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest { null.asInstanceOf[java.lang.Float], null.asInstanceOf[java.lang.Double]) - withParquetRDD(allNulls :: Nil) { rdd => - val rows = rdd.collect() + withParquetDataFrame(allNulls :: Nil) { df => + val rows = df.collect() assert(rows.size === 1) assert(rows.head === Row(Seq.fill(5)(null): _*)) } @@ -189,8 +210,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest { None.asInstanceOf[Option[Long]], None.asInstanceOf[Option[String]]) - withParquetRDD(allNones :: Nil) { rdd => - val rows = rdd.collect() + withParquetDataFrame(allNones :: Nil) { df => + val rows = df.collect() assert(rows.size === 1) assert(rows.head === Row(Seq.fill(3)(null): _*)) } @@ -285,4 +306,115 @@ class ParquetIOSuite extends QueryTest with ParquetTest { expectedSchema.checkContains(actualSchema) } } + + test("save - overwrite") { + withParquetFile((1 to 10).map(i => (i, i.toString))) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Overwrite, Map("path" -> file)) + checkAnswer(parquetFile(file), newData.map(Row.fromTuple)) + } + } + + test("save - ignore") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Ignore, Map("path" -> file)) + checkAnswer(parquetFile(file), data.map(Row.fromTuple)) + } + } + + test("save - throw") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + val errorMessage = intercept[Throwable] { + newData.toDF().save( + "org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> file)) + }.getMessage + assert(errorMessage.contains("already exists")) + } + } + + test("save - append") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Append, Map("path" -> file)) + checkAnswer(parquetFile(file), (data ++ newData).map(Row.fromTuple)) + } + } + + test("SPARK-6315 regression test") { + // Spark 1.1 and prior versions write Spark schema as case class string into Parquet metadata. + // This has been deprecated by JSON format since 1.2. Notice that, 1.3 further refactored data + // types API, and made StructType.fields an array. This makes the result of StructType.toString + // different from prior versions: there's no "Seq" wrapping the fields part in the string now. + val sparkSchema = + "StructType(Seq(StructField(a,BooleanType,false),StructField(b,IntegerType,false)))" + + // The Parquet schema is intentionally made different from the Spark schema. Because the new + // Parquet data source simply falls back to the Parquet schema once it fails to parse the Spark + // schema. By making these two different, we are able to assert the old style case class string + // is parsed successfully. + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required int32 c; + |} + """.stripMargin) + + withTempPath { location => + val extraMetadata = Map(RowReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) + val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") + val path = new Path(location.getCanonicalPath) + + ParquetFileWriter.writeMetadataFile( + sparkContext.hadoopConfiguration, + path, + new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) + + assertResult(parquetFile(path.toString).schema) { + StructType( + StructField("a", BooleanType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: + Nil) + } + } + } +} + +class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } + + test("SPARK-6330 regression test") { + // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: + // IllegalArgumentException: Wrong FS: hdfs://..., expected: file:/// + intercept[java.io.FileNotFoundException] { + sqlContext.parquetFile("file:///nonexistent") + } + val errorMessage = intercept[Throwable] { + sqlContext.parquetFile("hdfs://nonexistent") + }.toString + assert(errorMessage.contains("UnknownHostException")) + } +} + +class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala new file mode 100644 index 0000000000000..adb3c9391f6c2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.parquet + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.parquet.ParquetRelation2._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{QueryTest, Row, SQLContext} + +// The data where the partitioning key exists only in the directory structure. +case class ParquetData(intField: Int, stringField: String) + +// The data that also includes the partitioning key +case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) + +class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { + override val sqlContext: SQLContext = TestSQLContext + + import sqlContext._ + import sqlContext.implicits._ + + val defaultPartitionName = "__NULL__" + + test("column type inference") { + def check(raw: String, literal: Literal): Unit = { + assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) + } + + check("10", Literal(10, IntegerType)) + check("1000000000000000", Literal(1000000000000000L, LongType)) + check("1.5", Literal(1.5, FloatType)) + check("hello", Literal("hello", StringType)) + check(defaultPartitionName, Literal(null, NullType)) + } + + test("parse partition") { + def check(path: String, expected: PartitionValues): Unit = { + assert(expected === parsePartition(new Path(path), defaultPartitionName)) + } + + def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { + val message = intercept[T] { + parsePartition(new Path(path), defaultPartitionName) + }.getMessage + + assert(message.contains(expected)) + } + + check( + "file:///", + PartitionValues( + ArrayBuffer.empty[String], + ArrayBuffer.empty[Literal])) + + check( + "file://path/a=10", + PartitionValues( + ArrayBuffer("a"), + ArrayBuffer(Literal(10, IntegerType)))) + + check( + "file://path/a=10/b=hello/c=1.5", + PartitionValues( + ArrayBuffer("a", "b", "c"), + ArrayBuffer( + Literal(10, IntegerType), + Literal("hello", StringType), + Literal(1.5, FloatType)))) + + check( + "file://path/a=10/b_hello/c=1.5", + PartitionValues( + ArrayBuffer("c"), + ArrayBuffer(Literal(1.5, FloatType)))) + + checkThrows[AssertionError]("file://path/=10", "Empty partition column name") + checkThrows[AssertionError]("file://path/a=", "Empty partition column value") + } + + test("parse partitions") { + def check(paths: Seq[String], spec: PartitionSpec): Unit = { + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName) === spec) + } + + check(Seq( + "hdfs://host:9000/path/a=10/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType))), + Seq(Partition(Row(10, "hello"), "hdfs://host:9000/path/a=10/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", FloatType), + StructField("b", StringType))), + Seq( + Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), + Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=20", + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType))), + Seq( + Partition(Row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"), + Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", FloatType), + StructField("b", StringType))), + Seq( + Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + } + + test("read partitioned table - normal case") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeParquetFile( + (1 to 10).map(i => ParquetData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + parquetFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, 1, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, i.toString, pi, "foo")) + } + } + } + + test("read partitioned table - partition key included in Parquet file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeParquetFile( + (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + parquetFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, 1, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, "foo")) + } + } + } + + test("read partitioned table - with nulls") { + withTempDir { base => + for { + // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero... + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeParquetFile( + (1 to 10).map(i => ParquetData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val parquetRelation = load( + "org.apache.spark.sql.parquet", + Map( + "path" -> base.getCanonicalPath, + ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + + parquetRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi IS NULL"), + for { + i <- 1 to 10 + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, null, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + } yield Row(i, i.toString, pi, null)) + } + } + } + + test("read partitioned table - with nulls and partition keys are included in Parquet file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeParquetFile( + (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val parquetRelation = load( + "org.apache.spark.sql.parquet", + Map( + "path" -> base.getCanonicalPath, + ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + + parquetRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, null)) + } + } + } + + test("read partitioned table - merging compatible schemas") { + withTempDir { base => + makeParquetFile( + (1 to 10).map(i => Tuple1(i)).toDF("intField"), + makePartitionDir(base, defaultPartitionName, "pi" -> 1)) + + makeParquetFile( + (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), + makePartitionDir(base, defaultPartitionName, "pi" -> 2)) + + load(base.getCanonicalPath, "org.apache.spark.sql.parquet").registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + (1 to 10).map(i => Row(i, null, 1)) ++ (1 to 10).map(i => Row(i, i.toString, 2))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 5ec7a156d9353..b98ba09ccfc2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.parquet -import org.apache.spark.sql.QueryTest +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -25,31 +27,34 @@ import org.apache.spark.sql.test.TestSQLContext._ /** * A test suite that tests various Parquet queries. */ -class ParquetQuerySuite extends QueryTest with ParquetTest { +class ParquetQuerySuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext - test("simple projection") { + test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_))) } } test("appending") { val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { - sql("INSERT INTO TABLE t SELECT * FROM t") + sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } + catalog.unregisterTable(Seq("tmp")) } - // This test case will trigger the NPE mentioned in - // https://issues.apache.org/jira/browse/PARQUET-151. - ignore("overwriting") { + test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM t") + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") checkAnswer(table("t"), data.map(Row.fromTuple)) } + catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -63,8 +68,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") val queryOutput = selfJoin.queryExecution.analyzed.output - assertResult(4, s"Field count mismatches")(queryOutput.size) - assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { + assertResult(4, "Field count mismatches")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { queryOutput.filter(_.name == "_1").map(_.exprId).size } @@ -73,7 +78,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { } test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) @@ -82,7 +87,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { } test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) @@ -92,18 +97,42 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } test("SPARK-5309 strings stored using dictionary compression in parquet") { withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), (0 until 10).map(i => Row("same", "run_" + i, 100))) - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), List(Row("same", "run_5", 100))) } } } + +class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 64274950b868e..61f1cf347ab0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ class ParquetSchemaSuite extends FunSuite with ParquetTest { val sqlContext = TestSQLContext @@ -33,9 +34,10 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. */ private def testSchema[T <: Product: ClassTag: TypeTag]( - testName: String, messageType: String): Unit = { + testName: String, messageType: String, isThriftDerived: Boolean = false): Unit = { test(testName) { - val actual = ParquetTypesConverter.convertFromAttributes(ScalaReflection.attributesFor[T]) + val actual = ParquetTypesConverter.convertFromAttributes( + ScalaReflection.attributesFor[T], isThriftDerived) val expected = MessageTypeParser.parseMessageType(messageType) actual.checkContains(expected) expected.checkContains(actual) @@ -55,7 +57,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { |} """.stripMargin) - testSchema[(Byte, Short, Int, Long)]( + testSchema[(Byte, Short, Int, Long, java.sql.Date)]( "logical integral types", """ |message root { @@ -63,6 +65,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { | required int32 _2 (INT_16); | required int32 _3 (INT_32); | required int64 _4 (INT_64); + | optional int32 _5 (DATE); |} """.stripMargin) @@ -146,6 +149,29 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { |} """.stripMargin) + // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated + // as expected from attributes + testSchema[(Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( + "thrift generated parquet schema", + """ + |message root { + | optional binary _1 (UTF8); + | optional binary _2 (UTF8); + | optional binary _3 (UTF8); + | optional group _4 (LIST) { + | repeated int32 _4_tuple; + | } + | optional group _5 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value (LIST) { + | repeated int32 value_tuple; + | } + | } + | } + |} + """.stripMargin, isThriftDerived = true) + test("DataType string parser compatibility") { // This is the generated string from previous versions of the Spark SQL, using the following: // val schema = StructType(List( @@ -168,4 +194,86 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { assert(a.nullable === b.nullable) } } + + test("merge with metastore schema") { + // Field type conflict resolution + assertResult( + StructType(Seq( + StructField("lowerCase", StringType), + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("lowercase", StringType), + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // MetaStore schema is subset of parquet schema + assertResult( + StructType(Seq( + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Metastore schema contains additional non-nullable fields. + assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false), + StructField("lowerCase", BinaryType, nullable = false))), + + StructType(Seq( + StructField("UPPERCase", IntegerType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + + // Conflicting non-nullable field names + intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq(StructField("lower", StringType, nullable = false))), + StructType(Seq(StructField("lowerCase", BinaryType)))) + } + } + + test("merge missing nullable fields from Metastore schema") { + // Standard case: Metastore schema contains additional nullable fields not present + // in the Parquet file schema. + assertResult( + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true)))) { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + } + + // Merge should fail if the Metastore contains any additional fields that are not + // nullable. + assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = false))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index b02389978b625..3fe25e2a54e86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.sources -import java.io.File +import java.io.{IOException, File} +import org.apache.spark.sql.AnalysisException import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.util @@ -62,6 +63,29 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { dropTempTable("jsonTable") } + test("CREATE TEMPORARY TABLE AS SELECT based on the file without write permission") { + val childPath = new File(path.toString, "child") + path.mkdir() + childPath.createNewFile() + path.setWritable(false) + + val e = intercept[IOException] { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + sql("SELECT a, b FROM jsonTable").collect() + } + assert(e.getMessage().contains("Unable to clear output directory")) + + path.setWritable(true) + } + test("create a table, drop it and create another one with the same name") { sql( s""" @@ -77,12 +101,10 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a, b FROM jsonTable"), sql("SELECT a, b FROM jt").collect()) - dropTempTable("jsonTable") - - val message = intercept[RuntimeException]{ + val message = intercept[DDLException]{ sql( s""" - |CREATE TEMPORARY TABLE jsonTable + |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( | path '${path.toString}' @@ -91,10 +113,25 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) }.getMessage assert( - message.contains(s"path ${path.toString} already exists."), + message.contains(s"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause."), "CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.") - // Explicitly delete it. + // Overwrite the temporary table. + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a * 4 FROM jt + """.stripMargin) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT a * 4 FROM jt").collect()) + + dropTempTable("jsonTable") + // Explicitly delete the data. if (path.exists()) Utils.deleteRecursively(path) sql( @@ -104,12 +141,12 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { |OPTIONS ( | path '${path.toString}' |) AS - |SELECT a * 4 FROM jt + |SELECT b FROM jt """.stripMargin) checkAnswer( sql("SELECT * FROM jsonTable"), - sql("SELECT a * 4 FROM jt").collect()) + sql("SELECT b FROM jt").collect()) dropTempTable("jsonTable") } @@ -144,4 +181,31 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) } } + + test("it is not allowed to write to a table while querying it.") { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jsonTable + """.stripMargin) + }.getMessage + assert( + message.contains("Cannot overwrite table "), + "Writing to a table while querying it should not be allowed.") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala new file mode 100644 index 0000000000000..54af50c6e10ad --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -0,0 +1,99 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ +import org.apache.spark.sql.types._ + +class DDLScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt)(sqlContext) + } +} + +case class SimpleDDLScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends BaseRelation with TableScan { + + override def schema = + StructType(Seq( + StructField("intType", IntegerType, nullable = false, + new MetadataBuilder().putString("comment", "test comment").build()), + StructField("stringType", StringType, nullable = false), + StructField("dateType", DateType, nullable = false), + StructField("timestampType", TimestampType, nullable = false), + StructField("doubleType", DoubleType, nullable = false), + StructField("bigintType", LongType, nullable = false), + StructField("tinyintType", ByteType, nullable = false), + StructField("decimalType", DecimalType.Unlimited, nullable = false), + StructField("fixedDecimalType", DecimalType(5,1), nullable = false), + StructField("binaryType", BinaryType, nullable = false), + StructField("booleanType", BooleanType, nullable = false), + StructField("smallIntType", ShortType, nullable = false), + StructField("floatType", FloatType, nullable = false), + StructField("mapType", MapType(StringType, StringType)), + StructField("arrayType", ArrayType(StringType)), + StructField("structType", + StructType(StructField("f1",StringType) :: + (StructField("f2",IntegerType)) :: Nil + ) + ) + )) + + + override def buildScan() = sqlContext.sparkContext.parallelize(from to to). + map(e => Row(s"people$e", e * 2)) +} + +class DDLTestSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE ddlPeople + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + } + + sqlTest( + "describe ddlPeople", + Seq( + Row("intType", "int", "test comment"), + Row("stringType", "string", ""), + Row("dateType", "date", ""), + Row("timestampType", "timestamp", ""), + Row("doubleType", "double", ""), + Row("bigintType", "bigint", ""), + Row("tinyintType", "tinyint", ""), + Row("decimalType", "decimal(10,0)", ""), + Row("fixedDecimalType", "decimal(5,1)", ""), + Row("binaryType", "binary", ""), + Row("booleanType", "boolean", ""), + Row("smallIntType", "smallint", ""), + Row("floatType", "float", ""), + Row("mapType", "map", ""), + Row("arrayType", "array", ""), + Row("structType", "struct", "") + )) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 9626252e742e5..33c67355967dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -28,7 +28,15 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter { implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext) { @transient override protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, functionRegistry, caseSensitive = false) + new Analyzer(catalog, functionRegistry, caseSensitive = false) { + override val extendedResolutionRules = + PreInsertCastAndRename :: + Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 390538d35a348..72ddc0ea2c8cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -32,31 +32,55 @@ class FilteredScanSource extends RelationProvider { } case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends PrunedFilteredScan { + extends BaseRelation + with PrunedFilteredScan { - override def schema = + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: - StructField("b", IntegerType, nullable = false) :: Nil) + StructField("b", IntegerType, nullable = false) :: + StructField("c", StringType, nullable = false) :: Nil) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) + case "c" => (i: Int) => Seq((i - 1 + 'a').toChar.toString * 10) } FiltersPushed.list = filters - val filterFunctions = filters.collect { + // Predicate test on integer column + def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case GreaterThan("a", v: Int) => (a: Int) => a > v case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a) + case IsNull("a") => (a: Int) => false // Int can't be null + case IsNotNull("a") => (a: Int) => true + case Not(pred) => (a: Int) => !translateFilterOnA(pred)(a) + case And(left, right) => (a: Int) => + translateFilterOnA(left)(a) && translateFilterOnA(right)(a) + case Or(left, right) => (a: Int) => + translateFilterOnA(left)(a) || translateFilterOnA(right)(a) + case _ => (a: Int) => true } - def eval(a: Int) = !filterFunctions.map(_(a)).contains(false) + // Predicate test on string column + def translateFilterOnC(filter: Filter): String => Boolean = filter match { + case StringStartsWith("c", v) => _.startsWith(v) + case StringEndsWith("c", v) => _.endsWith(v) + case StringContains("c", v) => _.contains(v) + case _ => (c: String) => true + } + + def eval(a: Int) = { + val c = (a - 1 + 'a').toChar.toString * 10 + !filters.map(translateFilterOnA(_)(a)).contains(false) && + !filters.map(translateFilterOnC(_)(c)).contains(false) + } sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) @@ -86,7 +110,7 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT * FROM oneToTenFiltered", - (1 to 10).map(i => Row(i, i * 2)).toSeq) + (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 10)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", @@ -121,20 +145,52 @@ class FilteredScanSuite extends DataSourceTest { (2 to 10 by 2).map(i => Row(i, i)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a = 1", + Seq(1).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE a IN (1,3,5)", + Seq(1,3,5).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE A = 1", + Seq(1).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE b = 2", + Seq(1).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE a IS NULL", + Seq.empty[Row]) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE a IS NOT NULL", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE a < 5 AND a > 1", + (2 to 4).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", - Seq(1,3,5).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a < 3 OR a > 8", + Seq(1, 2, 9, 10).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE A = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE NOT (a < 6)", + (6 to 10).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE b = 2", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", + Seq(Row(3, 3 * 2, "c" * 10))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'd%'", + Seq(Row(4, 4 * 2, "d" * 10))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%e%'", + Seq(Row(5, 5 * 2, "e" * 10))) testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) @@ -162,6 +218,10 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4) + testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala deleted file mode 100644 index f91cea6a37060..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources - -import java.io.File - -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util -import org.apache.spark.util.Utils - -class InsertIntoSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensisitiveContext._ - - var path: File = null - - override def beforeAll: Unit = { - path = util.getTempFilePath("jsonCTAS").getCanonicalFile - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") - sql( - s""" - |CREATE TEMPORARY TABLE jsonTable (a int, b string) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${path.toString}' - |) - """.stripMargin) - } - - override def afterAll: Unit = { - dropTempTable("jsonTable") - dropTempTable("jt") - if (path.exists()) Utils.deleteRecursively(path) - } - - test("Simple INSERT OVERWRITE a JSONRelation") { - sql( - s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, s"str$i")) - ) - } - - test("INSERT OVERWRITE a JSONRelation multiple times") { - sql( - s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - - sql( - s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - - sql( - s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, s"str$i")) - ) - } - - test("INSERT INTO not supported for JSONRelation for now") { - intercept[RuntimeException]{ - sql( - s""" - |INSERT INTO TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala new file mode 100644 index 0000000000000..b5b16f9546691 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.util +import org.apache.spark.util.Utils + +class InsertSuite extends DataSourceTest with BeforeAndAfterAll { + + import caseInsensisitiveContext._ + + var path: File = null + + override def beforeAll: Unit = { + path = util.getTempFilePath("jsonCTAS").getCanonicalFile + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + jsonRDD(rdd).registerTempTable("jt") + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable (a int, b string) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) + """.stripMargin) + } + + override def afterAll: Unit = { + dropTempTable("jsonTable") + dropTempTable("jt") + if (path.exists()) Utils.deleteRecursively(path) + } + + test("Simple INSERT OVERWRITE a JSONRelation") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + } + + test("PreInsert casting and renaming") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, a * 4 FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i * 2, s"${i * 4}")) + ) + + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a * 4 AS A, a * 6 as c FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i * 4, s"${i * 6}")) + ) + } + + test("SELECT clause generating a different number of columns is not allowed.") { + val message = intercept[RuntimeException] { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt + """.stripMargin) + }.getMessage + assert( + message.contains("generates the same number of columns as its schema"), + "SELECT clause generating a different number of columns should not be not allowed." + ) + } + + test("INSERT OVERWRITE a JSONRelation multiple times") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + } + + test("INSERT INTO not supported for JSONRelation for now") { + intercept[RuntimeException]{ + sql( + s""" + |INSERT INTO TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + } + } + + test("it is not allowed to write to a table while querying it.") { + val message = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jsonTable + """.stripMargin) + }.getMessage + assert( + message.contains("Cannot insert overwrite into table that is also being read from."), + "INSERT OVERWRITE to a table while querying it should not be allowed.") + } + + test("Caching") { + // Cached Query Execution + cacheTable("jsonTable") + assertCached(sql("SELECT * FROM jsonTable")) + checkAnswer( + sql("SELECT * FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i"))) + + assertCached(sql("SELECT a FROM jsonTable")) + checkAnswer( + sql("SELECT a FROM jsonTable"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT a FROM jsonTable WHERE a < 5")) + checkAnswer( + sql("SELECT a FROM jsonTable WHERE a < 5"), + (1 to 4).map(Row(_)).toSeq) + + assertCached(sql("SELECT a * 2 FROM jsonTable")) + checkAnswer( + sql("SELECT a * 2 FROM jsonTable"), + (1 to 10).map(i => Row(i * 2)).toSeq) + + assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) + checkAnswer( + sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + // Insert overwrite and keep the same schema. + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt + """.stripMargin) + // jsonTable should be recached. + assertCached(sql("SELECT * FROM jsonTable")) + // The cached data is the new data. + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a * 2, b FROM jt").collect()) + + // Verify uncaching + uncacheTable("jsonTable") + assertCached(sql("SELECT * FROM jsonTable"), 0) + } + + test("it's not allowed to insert into a relation that is not an InsertableRelation") { + sql( + """ + |CREATE TEMPORARY TABLE oneToTen + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq + ) + + val message = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt + """.stripMargin) + }.getMessage + assert( + message.contains("does not allow insertion."), + "It is not allowed to insert into a table that is not an InsertableRelation." + ) + + dropTempTable("oneToTen") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index a33cf1172cac9..08fb5380dc026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -31,7 +31,8 @@ class PrunedScanSource extends RelationProvider { } case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends PrunedScan { + extends BaseRelation + with PrunedScan { override def schema = StructType( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala new file mode 100644 index 0000000000000..8331a14c9295c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -0,0 +1,34 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.scalatest.FunSuite + +class ResolvedDataSourceSuite extends FunSuite { + + test("builtin sources") { + assert(ResolvedDataSource.lookupDataSource("jdbc") === + classOf[org.apache.spark.sql.jdbc.DefaultSource]) + + assert(ResolvedDataSource.lookupDataSource("json") === + classOf[org.apache.spark.sql.json.DefaultSource]) + + assert(ResolvedDataSource.lookupDataSource("parquet") === + classOf[org.apache.spark.sql.parquet.DefaultSource]) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index fe2f76cc397f5..607488ccfdd6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -21,10 +21,10 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.DataFrame -import org.apache.spark.util.Utils - import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.{SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { @@ -38,42 +38,60 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { override def beforeAll(): Unit = { originalDefaultSource = conf.defaultDataSourceName - conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json") path = util.getTempFilePath("datasource").getCanonicalFile val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) df = jsonRDD(rdd) + df.registerTempTable("jsonTable") } override def afterAll(): Unit = { - conf.setConf("spark.sql.default.datasource", originalDefaultSource) + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) } after { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) if (path.exists()) Utils.deleteRecursively(path) } def checkLoad(): Unit = { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") checkAnswer(load(path.toString), df.collect()) - checkAnswer(load("org.apache.spark.sql.json", ("path", path.toString)), df.collect()) + + // Test if we can pick up the data source name passed in load. + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + checkAnswer(load(path.toString, "org.apache.spark.sql.json"), df.collect()) + checkAnswer(load("org.apache.spark.sql.json", Map("path" -> path.toString)), df.collect()) + val schema = StructType(StructField("b", StringType, true) :: Nil) + checkAnswer( + load("org.apache.spark.sql.json", schema, Map("path" -> path.toString)), + sql("SELECT b FROM jsonTable").collect()) } - test("save with overwrite and load") { + test("save with path and load") { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") df.save(path.toString) - checkLoad + checkLoad() + } + + test("save with path and datasource, and load") { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.save(path.toString, "org.apache.spark.sql.json") + checkLoad() } test("save with data source and options, and load") { - df.save("org.apache.spark.sql.json", ("path", path.toString)) - checkLoad + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, Map("path" -> path.toString)) + checkLoad() } test("save and save again") { - df.save(path.toString) + df.save(path.toString, "org.apache.spark.sql.json") - val message = intercept[RuntimeException] { - df.save(path.toString) + var message = intercept[RuntimeException] { + df.save(path.toString, "org.apache.spark.sql.json") }.getMessage assert( @@ -82,7 +100,18 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { if (path.exists()) Utils.deleteRecursively(path) - df.save(path.toString) - checkLoad + df.save(path.toString, "org.apache.spark.sql.json") + checkLoad() + + df.save("org.apache.spark.sql.json", SaveMode.Overwrite, Map("path" -> path.toString)) + checkLoad() + + message = intercept[RuntimeException] { + df.save("org.apache.spark.sql.json", SaveMode.Append, Map("path" -> path.toString)) + }.getMessage + + assert( + message.contains("Append mode is not supported"), + "We should complain that 'Append mode is not supported' for JSON source.") } } \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 0a4d4b6342d4f..7928600ac2fb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -33,7 +33,7 @@ class SimpleScanSource extends RelationProvider { } case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends TableScan { + extends BaseRelation with TableScan { override def schema = StructType(StructField("i", IntegerType, nullable = false) :: Nil) @@ -51,10 +51,11 @@ class AllDataTypesScanSource extends SchemaRelationProvider { } case class AllDataTypesScan( - from: Int, - to: Int, - userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext) - extends TableScan { + from: Int, + to: Int, + userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext) + extends BaseRelation + with TableScan { override def schema = userSpecifiedSchema diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 123a1f629ab1c..140a592098a5b 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala old mode 100755 new mode 100644 index bb19ac232fcbe..6e3c92f721b3f --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -194,8 +194,8 @@ private[hive] object SparkSQLCLIDriver { val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb", classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState) - def promptWithCurrentDB = s"$prompt$currentDB" - def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic( + def promptWithCurrentDB: String = s"$prompt$currentDB" + def continuedPromptWithDBSpaces: String = continuedPrompt + ReflectionUtils.invokeStatic( classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) var currentPrompt = promptWithCurrentDB @@ -292,9 +292,13 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } } + var counter = 0 try { while (!out.checkError() && driver.getResults(res)) { - res.foreach(out.println) + res.foreach{ l => + counter += 1 + out.println(l) + } res.clear() } } catch { @@ -311,7 +315,11 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { ret = cret } - console.printInfo(s"Time taken: $timeTaken seconds", null) + var responseMsg = s"Time taken: $timeTaken seconds" + if (counter != 0) { + responseMsg += s", Fetched $counter row(s)" + } + console.printInfo(responseMsg , null) // Destroy the driver to release all the locks. driver.destroy() } else { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 60953576d0e37..8bca4b33b3ad1 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -121,6 +121,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { } test("Single command with -e") { - runCliWithin(1.minute, Seq("-e", "SHOW TABLES;"))("" -> "OK") + runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala deleted file mode 100644 index b52a51d11e4ad..0000000000000 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ /dev/null @@ -1,387 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.io.File -import java.net.ServerSocket -import java.sql.{Date, DriverManager, Statement} -import java.util.concurrent.TimeoutException - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration._ -import scala.concurrent.{Await, Promise} -import scala.sys.process.{Process, ProcessLogger} -import scala.util.Try - -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.jdbc.HiveDriver -import org.apache.hive.service.auth.PlainSaslHelper -import org.apache.hive.service.cli.GetInfoType -import org.apache.hive.service.cli.thrift.TCLIService.Client -import org.apache.hive.service.cli.thrift._ -import org.apache.thrift.protocol.TBinaryProtocol -import org.apache.thrift.transport.TSocket -import org.scalatest.FunSuite - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.util.getTempFilePath -import org.apache.spark.sql.hive.HiveShim - -/** - * Tests for the HiveThriftServer2 using JDBC. - * - * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be - * rebuilt after changing HiveThriftServer2 related code. - */ -class HiveThriftServer2Suite extends FunSuite with Logging { - Class.forName(classOf[HiveDriver].getCanonicalName) - - object TestData { - def getTestDataFilePath(name: String) = { - Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") - } - - val smallKv = getTestDataFilePath("small_kv.txt") - val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt") - } - - def randomListeningPort = { - // Let the system to choose a random available port to avoid collision with other parallel - // builds. - val socket = new ServerSocket(0) - val port = socket.getLocalPort - socket.close() - port - } - - def withJdbcStatement( - serverStartTimeout: FiniteDuration = 1.minute, - httpMode: Boolean = false)( - f: Statement => Unit) { - val port = randomListeningPort - - startThriftServer(port, serverStartTimeout, httpMode) { - val jdbcUri = if (httpMode) { - s"jdbc:hive2://${"localhost"}:$port/" + - "default?hive.server2.transport.mode=http;hive.server2.thrift.http.path=cliservice" - } else { - s"jdbc:hive2://${"localhost"}:$port/" - } - - val user = System.getProperty("user.name") - val connection = DriverManager.getConnection(jdbcUri, user, "") - val statement = connection.createStatement() - - try { - f(statement) - } finally { - statement.close() - connection.close() - } - } - } - - def withCLIServiceClient( - serverStartTimeout: FiniteDuration = 1.minute)( - f: ThriftCLIServiceClient => Unit) { - val port = randomListeningPort - - startThriftServer(port) { - // Transport creation logics below mimics HiveConnection.createBinaryTransport - val rawTransport = new TSocket("localhost", port) - val user = System.getProperty("user.name") - val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) - val protocol = new TBinaryProtocol(transport) - val client = new ThriftCLIServiceClient(new Client(protocol)) - - transport.open() - - try { - f(client) - } finally { - transport.close() - } - } - } - - def startThriftServer( - port: Int, - serverStartTimeout: FiniteDuration = 1.minute, - httpMode: Boolean = false)( - f: => Unit) { - val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) - val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) - - val warehousePath = getTempFilePath("warehouse") - val metastorePath = getTempFilePath("metastore") - val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - - val command = - if (httpMode) { - s"""$startScript - | --master local - | --hiveconf hive.root.logger=INFO,console - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri - | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost - | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=http - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT}=$port - | --driver-class-path ${sys.props("java.class.path")} - | --conf spark.ui.enabled=false - """.stripMargin.split("\\s+").toSeq - } else { - s"""$startScript - | --master local - | --hiveconf hive.root.logger=INFO,console - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri - | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port - | --driver-class-path ${sys.props("java.class.path")} - | --conf spark.ui.enabled=false - """.stripMargin.split("\\s+").toSeq - } - - val serverRunning = Promise[Unit]() - val buffer = new ArrayBuffer[String]() - val LOGGING_MARK = - s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to " - var logTailingProcess: Process = null - var logFilePath: String = null - - def captureLogOutput(line: String): Unit = { - buffer += line - if (line.contains("ThriftBinaryCLIService listening on") || - line.contains("Started ThriftHttpCLIService in http")) { - serverRunning.success(()) - } - } - - def captureThriftServerOutput(source: String)(line: String): Unit = { - if (line.startsWith(LOGGING_MARK)) { - logFilePath = line.drop(LOGGING_MARK.length).trim - // Ensure that the log file is created so that the `tail' command won't fail - Try(new File(logFilePath).createNewFile()) - logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath") - .run(ProcessLogger(captureLogOutput, _ => ())) - } - } - - val env = Seq( - // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths - "SPARK_TESTING" -> "0") - - Process(command, None, env: _*).run(ProcessLogger( - captureThriftServerOutput("stdout"), - captureThriftServerOutput("stderr"))) - - try { - Await.result(serverRunning.future, serverStartTimeout) - f - } catch { - case cause: Exception => - cause match { - case _: TimeoutException => - logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause) - case _ => - } - logError( - s""" - |===================================== - |HiveThriftServer2Suite failure output - |===================================== - |HiveThriftServer2 command line: ${command.mkString(" ")} - |Binding port: $port - |System user: ${System.getProperty("user.name")} - | - |${buffer.mkString("\n")} - |========================================= - |End HiveThriftServer2Suite failure output - |========================================= - """.stripMargin, cause) - throw cause - } finally { - warehousePath.delete() - metastorePath.delete() - Process(stopScript, None, env: _*).run().exitValue() - // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. - Thread.sleep(3.seconds.toMillis) - Option(logTailingProcess).map(_.destroy()) - Option(logFilePath).map(new File(_).delete()) - } - } - - test("Test JDBC query execution") { - withJdbcStatement() { statement => - val queries = Seq( - "SET spark.sql.shuffle.partitions=3", - "DROP TABLE IF EXISTS test", - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", - "CACHE TABLE test") - - queries.foreach(statement.execute) - - assertResult(5, "Row count mismatch") { - val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") - resultSet.next() - resultSet.getInt(1) - } - } - } - - test("Test JDBC query execution in Http Mode") { - withJdbcStatement(httpMode = true) { statement => - val queries = Seq( - "SET spark.sql.shuffle.partitions=3", - "DROP TABLE IF EXISTS test", - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", - "CACHE TABLE test") - - queries.foreach(statement.execute) - - assertResult(5, "Row count mismatch") { - val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") - resultSet.next() - resultSet.getInt(1) - } - } - } - - test("SPARK-3004 regression: result set containing NULL") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_null", - "CREATE TABLE test_null(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null") - - queries.foreach(statement.execute) - - val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL") - - (0 until 5).foreach { _ => - resultSet.next() - assert(resultSet.getInt(1) === 0) - assert(resultSet.wasNull()) - } - - assert(!resultSet.next()) - } - } - - test("GetInfo Thrift API") { - withCLIServiceClient() { client => - val user = System.getProperty("user.name") - val sessionHandle = client.openSession(user, "") - - assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { - client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue - } - - assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { - client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue - } - - assertResult(true, "Spark version shouldn't be \"Unknown\"") { - val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue - logInfo(s"Spark version: $version") - version != "Unknown" - } - } - } - - test("Checks Hive version") { - withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") - resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") - } - } - - test("Checks Hive version in Http Mode") { - withJdbcStatement(httpMode = true) { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") - resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") - } - } - - test("SPARK-4292 regression: result set iterator issue") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_4292", - "CREATE TABLE test_4292(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292") - - queries.foreach(statement.execute) - - val resultSet = statement.executeQuery("SELECT key FROM test_4292") - - Seq(238, 86, 311, 27, 165).foreach { key => - resultSet.next() - assert(resultSet.getInt(1) === key) - } - - statement.executeQuery("DROP TABLE IF EXISTS test_4292") - } - } - - test("SPARK-4309 regression: Date type support") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_date", - "CREATE TABLE test_date(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date") - - queries.foreach(statement.execute) - - assertResult(Date.valueOf("2011-01-01")) { - val resultSet = statement.executeQuery( - "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1") - resultSet.next() - resultSet.getDate(1) - } - } - } - - test("SPARK-4407 regression: Complex type support") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_map", - "CREATE TABLE test_map(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") - - queries.foreach(statement.execute) - - assertResult("""{238:"val_238"}""") { - val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1") - resultSet.next() - resultSet.getString(1) - } - - assertResult("""["238","val_238"]""") { - val resultSet = statement.executeQuery( - "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1") - resultSet.next() - resultSet.getString(1) - } - } - } -} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala new file mode 100644 index 0000000000000..d783d487b5c60 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.File +import java.sql.{Date, DriverManager, Statement} + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise} +import scala.sys.process.{Process, ProcessLogger} +import scala.util.{Random, Try} + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.jdbc.HiveDriver +import org.apache.hive.service.auth.PlainSaslHelper +import org.apache.hive.service.cli.GetInfoType +import org.apache.hive.service.cli.thrift.TCLIService.Client +import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.transport.TSocket +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.util.Utils + +object TestData { + def getTestDataFilePath(name: String) = { + Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") + } + + val smallKv = getTestDataFilePath("small_kv.txt") + val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt") +} + +class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { + override def mode = ServerMode.binary + + private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { + // Transport creation logics below mimics HiveConnection.createBinaryTransport + val rawTransport = new TSocket("localhost", serverPort) + val user = System.getProperty("user.name") + val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) + val protocol = new TBinaryProtocol(transport) + val client = new ThriftCLIServiceClient(new Client(protocol)) + + transport.open() + try f(client) finally transport.close() + } + + test("GetInfo Thrift API") { + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue + } + + assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue + } + + assertResult(true, "Spark version shouldn't be \"Unknown\"") { + val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue + logInfo(s"Spark version: $version") + version != "Unknown" + } + } + } + + test("JDBC query execution") { + withJdbcStatement { statement => + val queries = Seq( + "SET spark.sql.shuffle.partitions=3", + "DROP TABLE IF EXISTS test", + "CREATE TABLE test(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", + "CACHE TABLE test") + + queries.foreach(statement.execute) + + assertResult(5, "Row count mismatch") { + val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") + resultSet.next() + resultSet.getInt(1) + } + } + } + + test("Checks Hive version") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + } + } + + test("SPARK-3004 regression: result set containing NULL") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_null", + "CREATE TABLE test_null(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null") + + queries.foreach(statement.execute) + + val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL") + + (0 until 5).foreach { _ => + resultSet.next() + assert(resultSet.getInt(1) === 0) + assert(resultSet.wasNull()) + } + + assert(!resultSet.next()) + } + } + + test("SPARK-4292 regression: result set iterator issue") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_4292", + "CREATE TABLE test_4292(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292") + + queries.foreach(statement.execute) + + val resultSet = statement.executeQuery("SELECT key FROM test_4292") + + Seq(238, 86, 311, 27, 165).foreach { key => + resultSet.next() + assert(resultSet.getInt(1) === key) + } + + statement.executeQuery("DROP TABLE IF EXISTS test_4292") + } + } + + test("SPARK-4309 regression: Date type support") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_date", + "CREATE TABLE test_date(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date") + + queries.foreach(statement.execute) + + assertResult(Date.valueOf("2011-01-01")) { + val resultSet = statement.executeQuery( + "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1") + resultSet.next() + resultSet.getDate(1) + } + } + } + + test("SPARK-4407 regression: Complex type support") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") + + queries.foreach(statement.execute) + + assertResult("""{238:"val_238"}""") { + val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1") + resultSet.next() + resultSet.getString(1) + } + + assertResult("""["238","val_238"]""") { + val resultSet = statement.executeQuery( + "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1") + resultSet.next() + resultSet.getString(1) + } + } + } +} + +class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { + override def mode = ServerMode.http + + test("JDBC query execution") { + withJdbcStatement { statement => + val queries = Seq( + "SET spark.sql.shuffle.partitions=3", + "DROP TABLE IF EXISTS test", + "CREATE TABLE test(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", + "CACHE TABLE test") + + queries.foreach(statement.execute) + + assertResult(5, "Row count mismatch") { + val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") + resultSet.next() + resultSet.getInt(1) + } + } + } + + test("Checks Hive version") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + } + } +} + +object ServerMode extends Enumeration { + val binary, http = Value +} + +abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { + Class.forName(classOf[HiveDriver].getCanonicalName) + + private def jdbcUri = if (mode == ServerMode.http) { + s"""jdbc:hive2://localhost:$serverPort/ + |default? + |hive.server2.transport.mode=http; + |hive.server2.thrift.http.path=cliservice + """.stripMargin.split("\n").mkString.trim + } else { + s"jdbc:hive2://localhost:$serverPort/" + } + + protected def withJdbcStatement(f: Statement => Unit): Unit = { + val connection = DriverManager.getConnection(jdbcUri, user, "") + val statement = connection.createStatement() + + try f(statement) finally { + statement.close() + connection.close() + } + } +} + +abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging { + def mode: ServerMode.Value + + private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$") + private val LOG_FILE_MARK = s"starting $CLASS_NAME, logging to " + + private val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + private val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) + + private var listeningPort: Int = _ + protected def serverPort: Int = listeningPort + + protected def user = System.getProperty("user.name") + + private var warehousePath: File = _ + private var metastorePath: File = _ + private def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" + + private val pidDir: File = Utils.createTempDir("thriftserver-pid") + private var logPath: File = _ + private var logTailingProcess: Process = _ + private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] + + private def serverStartCommand(port: Int) = { + val portConf = if (mode == ServerMode.binary) { + ConfVars.HIVE_SERVER2_THRIFT_PORT + } else { + ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT + } + + s"""$startScript + | --master local + | --hiveconf hive.root.logger=INFO,console + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost + | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf $portConf=$port + | --driver-class-path ${sys.props("java.class.path")} + | --conf spark.ui.enabled=false + """.stripMargin.split("\\s+").toSeq + } + + private def startThriftServer(port: Int, attempt: Int) = { + warehousePath = util.getTempFilePath("warehouse") + metastorePath = util.getTempFilePath("metastore") + logPath = null + logTailingProcess = null + + val command = serverStartCommand(port) + + diagnosisBuffer ++= + s""" + |### Attempt $attempt ### + |HiveThriftServer2 command line: $command + |Listening port: $port + |System user: $user + """.stripMargin.split("\n") + + logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt") + + val env = Seq( + // Disables SPARK_TESTING to exclude log4j.properties in test directories. + "SPARK_TESTING" -> "0", + // Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be started + // at a time, which is not Jenkins friendly. + "SPARK_PID_DIR" -> pidDir.getCanonicalPath) + + logPath = Process(command, None, env: _*).lines.collectFirst { + case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) + }.getOrElse { + throw new RuntimeException("Failed to find HiveThriftServer2 log file.") + } + + val serverStarted = Promise[Unit]() + + // Ensures that the following "tail" command won't fail. + logPath.createNewFile() + logTailingProcess = + // Using "-n +0" to make sure all lines in the log file are checked. + Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger( + (line: String) => { + diagnosisBuffer += line + + if (line.contains("ThriftBinaryCLIService listening on") || + line.contains("Started ThriftHttpCLIService in http")) { + serverStarted.trySuccess(()) + } else if (line.contains("HiveServer2 is stopped")) { + // This log line appears when the server fails to start and terminates gracefully (e.g. + // because of port contention). + serverStarted.tryFailure(new RuntimeException("Failed to start HiveThriftServer2")) + } + })) + + Await.result(serverStarted.future, 2.minute) + } + + private def stopThriftServer(): Unit = { + // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. + Process(stopScript, None, "SPARK_PID_DIR" -> pidDir.getCanonicalPath).run().exitValue() + Thread.sleep(3.seconds.toMillis) + + warehousePath.delete() + warehousePath = null + + metastorePath.delete() + metastorePath = null + + Option(logPath).foreach(_.delete()) + logPath = null + + Option(logTailingProcess).foreach(_.destroy()) + logTailingProcess = null + } + + private def dumpLogs(): Unit = { + logError( + s""" + |===================================== + |HiveThriftServer2Suite failure output + |===================================== + |${diagnosisBuffer.mkString("\n")} + |========================================= + |End HiveThriftServer2Suite failure output + |========================================= + """.stripMargin) + } + + override protected def beforeAll(): Unit = { + // Chooses a random port between 10000 and 19999 + listeningPort = 10000 + Random.nextInt(10000) + diagnosisBuffer.clear() + + // Retries up to 3 times with different port numbers if the server fails to start + (1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) => + started.orElse { + listeningPort += 1 + stopThriftServer() + Try(startThriftServer(listeningPort, attempt)) + } + }.recover { + case cause: Throwable => + dumpLogs() + throw cause + }.get + + logInfo(s"HiveThriftServer2 started successfully") + } + + override protected def afterAll(): Unit = { + stopThriftServer() + logInfo("HiveThriftServer2 stopped") + } +} diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index ea9d61d8d0f5e..13116b40bb259 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -185,6 +185,10 @@ private[hive] class SparkExecuteStatementOperation( def run(): Unit = { logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) + hiveContext.sparkContext.setJobDescription(statement) + sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } try { result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) @@ -194,10 +198,6 @@ private[hive] class SparkExecuteStatementOperation( logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => } - hiveContext.sparkContext.setJobDescription(statement) - sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } iter = { val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index 71e3954b2c7ac..9b8faeff94eab 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -156,6 +156,10 @@ private[hive] class SparkExecuteStatementOperation( def run(): Unit = { logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) + hiveContext.sparkContext.setJobDescription(statement) + sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } try { result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) @@ -165,10 +169,6 @@ private[hive] class SparkExecuteStatementOperation( logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => } - hiveContext.sparkContext.setJobDescription(statement) - sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } iter = { val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 0d934620aca09..6126ce7130426 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -225,6 +225,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Needs constant object inspectors "udf_round", + // the table src(key INT, value STRING) is not the same as HIVE unittest. In Hive + // is src(key STRING, value STRING), and in the reflect.q, it failed in + // Integer.valueOf, which expect the first argument passed as STRING type not INT. + "udf_reflect", + // Sort with Limit clause causes failure. "ctas", "ctas_hadoop20", @@ -357,6 +362,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "database_drop", "database_location", "database_properties", + "date_1", "date_2", "date_3", "date_4", @@ -517,6 +523,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "inputddl2", "inputddl3", "inputddl4", + "inputddl5", "inputddl6", "inputddl7", "inputddl8", @@ -625,6 +632,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "mapreduce8", "merge1", "merge2", + "merge4", "mergejoins", "multiMapJoin1", "multiMapJoin2", @@ -638,6 +646,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "nonblock_op_deduplicate", "notable_alias1", "notable_alias2", + "nullformatCTAS", "nullgroup", "nullgroup2", "nullgroup3", @@ -883,6 +892,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_power", "udf_radians", "udf_rand", + "udf_reflect2", "udf_regexp", "udf_regexp_extract", "udf_regexp_replace", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 58b0722464be8..e27cf9a294c69 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../../pom.xml @@ -84,6 +84,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index f6d9027f90a99..5c8782fc7a49d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -18,30 +18,28 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, InputStreamReader, PrintStream} -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.collection.JavaConversions._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.parse.VariableSubstitution +import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators, OverrideCatalog, OverrideFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand, QueryExecutionException} -import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DescribeHiveTableCommand} -import org.apache.spark.sql.sources.{CreateTableUsing, DataSourceStrategy} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand} +import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} +import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy} import org.apache.spark.sql.types._ /** @@ -63,34 +61,49 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] def convertMetastoreParquet: Boolean = getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" + /** + * When true, also tries to merge possibly different but compatible Parquet schemas in different + * Parquet data files. + * + * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. + */ + protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = + getConf("spark.sql.hive.convertMetastoreParquet.mergeSchema", "false") == "true" + + /** + * When true, a table created by a Hive CTAS statement (no USING clause) will be + * converted to a data source table, using the data source set by spark.sql.sources.default. + * The table in CTAS statement will be converted when it meets any of the following conditions: + * - The CTAS does not specify any of a SerDe (ROW FORMAT SERDE), a File Format (STORED AS), or + * a Storage Hanlder (STORED BY), and the value of hive.default.fileformat in hive-site.xml + * is either TextFile or SequenceFile. + * - The CTAS statement specifies TextFile (STORED AS TEXTFILE) as the file format and no SerDe + * is specified (no ROW FORMAT SERDE clause). + * - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format + * and no SerDe is specified (no ROW FORMAT SERDE clause). + */ + protected[sql] def convertCTAS: Boolean = + getConf("spark.sql.hive.convertCTAS", "false").toBoolean + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution(plan) + @transient + protected[sql] val ddlParserWithHiveQL = new DDLParser(HiveQl.parseSql(_)) + override def sql(sqlText: String): DataFrame = { val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) // TODO: Create a framework for registering parsers instead of just hardcoding if statements. if (conf.dialect == "sql") { super.sql(substituted) } else if (conf.dialect == "hiveql") { - DataFrame(this, - ddlParser(sqlText, exceptionOnError = false).getOrElse(HiveQl.parseSql(substituted))) + val ddlPlan = ddlParserWithHiveQL(sqlText, exceptionOnError = false) + DataFrame(this, ddlPlan.getOrElse(HiveQl.parseSql(substituted))) } else { - sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") + sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } } - /** - * Creates a table using the schema of the given class. - * - * @param tableName The name of the table to create. - * @param allowExisting When false, an exception will be thrown if the table already exists. - * @tparam A A case class that is used to describe the schema of the table to be created. - */ - @Deprecated - def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) { - catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) - } - /** * Invalidate and refresh all the cached the metadata of the given table. For performance reasons, * Spark SQL or the external data source library it uses might cache certain metadata about a @@ -107,70 +120,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.invalidateTable("default", tableName) } - @Experimental - def createTable(tableName: String, path: String, allowExisting: Boolean): Unit = { - val dataSourceName = conf.defaultDataSourceName - createTable(tableName, dataSourceName, allowExisting, ("path", path)) - } - - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - allowExisting: Boolean, - option: (String, String), - options: (String, String)*): Unit = { - val cmd = - CreateTableUsing( - tableName, - userSpecifiedSchema = None, - dataSourceName, - temporary = false, - (option +: options).toMap, - allowExisting) - executePlan(cmd).toRdd - } - - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - schema: StructType, - allowExisting: Boolean, - option: (String, String), - options: (String, String)*): Unit = { - val cmd = - CreateTableUsing( - tableName, - userSpecifiedSchema = Some(schema), - dataSourceName, - temporary = false, - (option +: options).toMap, - allowExisting) - executePlan(cmd).toRdd - } - - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - allowExisting: Boolean, - options: java.util.Map[String, String]): Unit = { - val opts = options.toSeq - createTable(tableName, dataSourceName, allowExisting, opts.head, opts.tail:_*) - } - - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - schema: StructType, - allowExisting: Boolean, - options: java.util.Map[String, String]): Unit = { - val opts = options.toSeq - createTable(tableName, dataSourceName, schema, allowExisting, opts.head, opts.tail:_*) - } - /** * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. @@ -180,7 +129,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { */ @Experimental def analyze(tableName: String) { - val relation = EliminateAnalysisOperators(catalog.lookupRelation(Seq(tableName))) + val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) relation match { case relation: MetastoreRelation => @@ -236,18 +185,19 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableFullName = relation.hiveQlTable.getDbName + "." + relation.hiveQlTable.getTableName - catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + catalog.synchronized { + catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + } } case otherRelation => - throw new NotImplementedError( - s"Analyze has only implemented for Hive tables, " + - s"but $tableName is a ${otherRelation.nodeName}") + throw new UnsupportedOperationException( + s"Analyze only works for Hive tables, but $tableName is a ${otherRelation.nodeName}") } } // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @transient - protected lazy val outputBuffer = new java.io.OutputStream { + protected lazy val outputBuffer = new java.io.OutputStream { var pos: Int = 0 var buffer = new Array[Int](10240) def write(i: Int): Unit = { @@ -255,7 +205,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { pos = (pos + 1) % buffer.size } - override def toString = { + override def toString: String = { val (end, start) = buffer.splitAt(pos) val input = new java.io.InputStream { val iterator = (start ++ end).iterator @@ -282,22 +232,25 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be * set in the SQLConf *as well as* in the HiveConf. */ - @transient protected[hive] lazy val (hiveconf, sessionState) = - Option(SessionState.get()) - .orElse { - val newState = new SessionState(new HiveConf(classOf[SessionState])) - // Only starts newly created `SessionState` instance. Any existing `SessionState` instance - // returned by `SessionState.get()` must be the most recently started one. - SessionState.start(newState) - Some(newState) - } - .map { state => - setConf(state.getConf.getAllProperties) - if (state.out == null) state.out = new PrintStream(outputBuffer, true, "UTF-8") - if (state.err == null) state.err = new PrintStream(outputBuffer, true, "UTF-8") - (state.getConf, state) - } - .get + @transient protected[hive] lazy val sessionState: SessionState = { + var state = SessionState.get() + if (state == null) { + state = new SessionState(new HiveConf(classOf[SessionState])) + SessionState.start(state) + } + if (state.out == null) { + state.out = new PrintStream(outputBuffer, true, "UTF-8") + } + if (state.err == null) { + state.err = new PrintStream(outputBuffer, true, "UTF-8") + } + state + } + + @transient protected[hive] lazy val hiveconf: HiveConf = { + setConf(sessionState.getConf.getAllProperties) + sessionState.getConf + } override def setConf(key: String, value: String): Unit = { super.setConf(key, value) @@ -311,16 +264,21 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Note that HiveUDFs will be overridden by functions registered in this context. @transient override protected[sql] lazy val functionRegistry = - new HiveFunctionRegistry with OverrideFunctionRegistry + new HiveFunctionRegistry with OverrideFunctionRegistry { + def caseSensitive: Boolean = false + } /* An analyzer that uses the Hive metastore. */ @transient override protected[sql] lazy val analyzer = new Analyzer(catalog, functionRegistry, caseSensitive = false) { - override val extendedRules = + override val extendedResolutionRules = + catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUdfs :: + ResolveUdtfsAlias :: + sources.PreInsertCastAndRename :: Nil } @@ -420,6 +378,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** Extends QueryExecution with hive specific features. */ protected[sql] class QueryExecution(logicalPlan: LogicalPlan) extends super.QueryExecution(logicalPlan) { + // Like what we do in runHive, makes sure the session represented by the + // `sessionState` field is activated. + if (SessionState.get() != sessionState) { + SessionState.start(sessionState) + } /** * Returns the result as a hive compatible sequence of strings. For native commands, the @@ -475,7 +438,7 @@ private object HiveContext { toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "NULL" - case (d: Date, DateType) => new DateWritable(d).toString + case (d: Int, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (decimal: java.math.BigDecimal, DecimalType()) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 82dba99900df9..4afa2e71d77cc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -267,7 +267,8 @@ private[hive] trait HiveInspectors { val temp = new Array[Byte](writable.getLength) System.arraycopy(writable.getBytes, 0, temp, 0, temp.length) temp - case poi: WritableConstantDateObjectInspector => poi.getWritableConstantValue.get() + case poi: WritableConstantDateObjectInspector => + DateUtils.fromJavaDate(poi.getWritableConstantValue.get()) case mi: StandardConstantMapObjectInspector => // take the value from the map inspector object, rather than the input data mi.getWritableConstantValue.map { case (k, v) => @@ -304,7 +305,8 @@ private[hive] trait HiveInspectors { System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength()) result case x: DateObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).get() + DateUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) + case x: DateObjectInspector => DateUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object // if next timestamp is null, so Timestamp object is cloned case x: TimestampObjectInspector if x.preferWritable() => @@ -343,6 +345,9 @@ private[hive] trait HiveInspectors { case _: JavaHiveDecimalObjectInspector => (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) + case _: JavaDateObjectInspector => + (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) + case soi: StandardStructObjectInspector => val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) (o: Any) => { @@ -426,7 +431,7 @@ private[hive] trait HiveInspectors { case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a) case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a) - case _: DateObjectInspector => a.asInstanceOf[java.sql.Date] + case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a) case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 48bea6c1bd685..a906c921df83a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -20,25 +20,27 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.{List => JList} -import com.google.common.cache.{LoadingCache, CacheLoader, CacheBuilder} - -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.hive.metastore.{Warehouse, TableType} -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition, AlreadyExistsException, FieldSchema} +import com.google.common.base.Objects +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Partition => TPartition, Table => TTable} +import org.apache.hadoop.hive.metastore.{TableType, Warehouse} import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.CreateTableDesc import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} +import org.apache.hadoop.util.ReflectionUtils import org.apache.spark.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog} +import org.apache.spark.sql.{SaveMode, AnalysisException, SQLContext} +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NoSuchTableException, Catalog, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.sources.{DDLParser, LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.parquet.{ParquetRelation2, Partition => ParquetPartition, PartitionSpec} +import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, DDLParser, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -51,12 +53,13 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with /** Connection to hive metastore. Usages should lock on `this`. */ protected[hive] val client = Hive.get(hive.hiveconf) + /** Usages should lock on `this`. */ protected[hive] lazy val hiveWarehouse = new Warehouse(hive.hiveconf) // TODO: Use this everywhere instead of tuples or databaseName, tableName,. /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) { - def toLowerCase = QualifiedTableName(database.toLowerCase, name.toLowerCase) + def toLowerCase: QualifiedTableName = QualifiedTableName(database.toLowerCase, name.toLowerCase) } /** A cache of Spark SQL data source tables that have been accessed. */ @@ -64,14 +67,36 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() { override def load(in: QualifiedTableName): LogicalPlan = { logDebug(s"Creating new cached data source for $in") - val table = client.getTable(in.database, in.name) - val schemaString = table.getProperty("spark.sql.sources.schema") - val userSpecifiedSchema = - if (schemaString == null) { - None - } else { - Some(DataType.fromJson(schemaString).asInstanceOf[StructType]) + val table = HiveMetastoreCatalog.this.synchronized { + client.getTable(in.database, in.name) + } + + def schemaStringFromParts: Option[String] = { + Option(table.getProperty("spark.sql.sources.schema.numParts")).map { numParts => + val parts = (0 until numParts.toInt).map { index => + val part = table.getProperty(s"spark.sql.sources.schema.part.${index}") + if (part == null) { + throw new AnalysisException( + s"Could not read schema from the metastore because it is corrupted " + + s"(missing part ${index} of the schema).") + } + + part + } + // Stick all parts back to a single schema string. + parts.mkString } + } + + // Originally, we used spark.sql.sources.schema to store the schema of a data source table. + // After SPARK-6024, we removed this flag. + // Although we are not using spark.sql.sources.schema any more, we need to still support. + val schemaString = + Option(table.getProperty("spark.sql.sources.schema")).orElse(schemaStringFromParts) + + val userSpecifiedSchema = + schemaString.map(s => DataType.fromJson(s).asInstanceOf[StructType]) + // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... val options = table.getTTable.getSd.getSerdeInfo.getParameters.toMap @@ -90,8 +115,16 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with CacheBuilder.newBuilder().maximumSize(1000).build(cacheLoader) } - def refreshTable(databaseName: String, tableName: String): Unit = { - cachedDataSourceTables.refresh(QualifiedTableName(databaseName, tableName).toLowerCase) + override def refreshTable(databaseName: String, tableName: String): Unit = { + // refreshTable does not eagerly reload the cache. It just invalidate the cache. + // Next time when we use the table, it will be populated in the cache. + // Since we also cache ParquetRealtions converted from Hive Parquet tables and + // adding converted ParquetRealtions into the cache is not defined in the load function + // of the cache (instead, we add the cache entry in convertToParquetRelation), + // it is better at here to invalidate the cache to avoid confusing waring logs from the + // cache loader (e.g. cannot find data source provider, which is only defined for + // data source table.). + invalidateTable(databaseName, tableName) } def invalidateTable(databaseName: String, tableName: String): Unit = { @@ -100,16 +133,10 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val caseSensitive: Boolean = false - /** * - * Creates a data source table (a table created with USING clause) in Hive's metastore. - * Returns true when the table has been created. Otherwise, false. - * @param tableName - * @param userSpecifiedSchema - * @param provider - * @param options - * @param isExternal - * @return - */ + /** + * Creates a data source table (a table created with USING clause) in Hive's metastore. + * Returns true when the table has been created. Otherwise, false. + */ def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], @@ -121,7 +148,14 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with tbl.setProperty("spark.sql.sources.provider", provider) if (userSpecifiedSchema.isDefined) { - tbl.setProperty("spark.sql.sources.schema", userSpecifiedSchema.get.json) + val threshold = hive.conf.schemaStringLengthThreshold + val schemaJsonString = userSpecifiedSchema.get.json + // Split the JSON string. + val parts = schemaJsonString.grouped(threshold).toSeq + tbl.setProperty("spark.sql.sources.schema.numParts", parts.size.toString) + parts.zipWithIndex.foreach { case (part, index) => + tbl.setProperty(s"spark.sql.sources.schema.part.${index}", part) + } } options.foreach { case (key, value) => tbl.setSerdeParam(key, value) } @@ -139,33 +173,48 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } - def hiveDefaultTableFilePath(tableName: String): String = { - hiveWarehouse.getTablePath(client.getDatabaseCurrent, tableName).toString + def hiveDefaultTableFilePath(tableName: String): String = synchronized { + val currentDatabase = client.getDatabase(hive.sessionState.getCurrentDatabase) + + hiveWarehouse.getTablePath(currentDatabase, tableName).toString } - def tableExists(tableIdentifier: Seq[String]): Boolean = { + def tableExists(tableIdentifier: Seq[String]): Boolean = synchronized { val tableIdent = processTableIdentifier(tableIdentifier) - val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( - hive.sessionState.getCurrentDatabase) + val databaseName = + tableIdent + .lift(tableIdent.size - 2) + .getOrElse(hive.sessionState.getCurrentDatabase) val tblName = tableIdent.last - try { - client.getTable(databaseName, tblName) != null - } catch { - case ie: InvalidTableException => false - } + client.getTable(databaseName, tblName, false) != null } def lookupRelation( tableIdentifier: Seq[String], - alias: Option[String]): LogicalPlan = synchronized { + alias: Option[String]): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( hive.sessionState.getCurrentDatabase) val tblName = tableIdent.last - val table = client.getTable(databaseName, tblName) + val table = try { + synchronized { + client.getTable(databaseName, tblName) + } + } catch { + case te: org.apache.hadoop.hive.ql.metadata.InvalidTableException => + throw new NoSuchTableException + } if (table.getProperty("spark.sql.sources.provider") != null) { - cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase) + val dataSourceTable = + cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase) + // Then, if alias is specified, wrap the table with a Subquery using the alias. + // Othersie, wrap the table with a Subquery using the table name. + val withAlias = + alias.map(a => Subquery(a, dataSourceTable)).getOrElse( + Subquery(tableIdent.last, dataSourceTable)) + + withAlias } else if (table.isView) { // if the unresolved relation is from hive view // parse the text into logic node. @@ -173,16 +222,111 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } else { val partitions: Seq[Partition] = if (table.isPartitioned) { - HiveShim.getAllPartitionsOf(client, table).toSeq + synchronized { + HiveShim.getAllPartitionsOf(client, table).toSeq + } } else { Nil } - // Since HiveQL is case insensitive for table names we make them all lowercase. - MetastoreRelation( - databaseName, tblName, alias)( - table.getTTable, partitions.map(part => part.getTPartition))(hive) + MetastoreRelation(databaseName, tblName, alias)( + table.getTTable, partitions.map(part => part.getTPartition))(hive) + } + } + + private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { + val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging + + // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to + // serialize the Metastore schema to JSON and pass it as a data source option because of the + // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. + val parquetOptions = Map( + ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) + val tableIdentifier = + QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) + + def getCached( + tableIdentifier: QualifiedTableName, + pathsInMetastore: Seq[String], + schemaInMetastore: StructType, + partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + cachedDataSourceTables.getIfPresent(tableIdentifier) match { + case null => None // Cache miss + case logical@LogicalRelation(parquetRelation: ParquetRelation2) => + // If we have the same paths, same schema, and same partition spec, + // we will use the cached Parquet Relation. + val useCached = + parquetRelation.paths.toSet == pathsInMetastore.toSet && + logical.schema.sameType(metastoreSchema) && + parquetRelation.maybePartitionSpec == partitionSpecInMetastore + + if (useCached) { + Some(logical) + } else { + // If the cached relation is not updated, we invalidate it right away. + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + case other => + logWarning( + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} shold be stored " + + s"as Parquet. However, we are getting a ${other} from the metastore cache. " + + s"This cached entry will be invalidated.") + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + } + + val result = if (metastoreRelation.hiveQlTable.isPartitioned) { + val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) + val partitionColumnDataTypes = partitionSchema.map(_.dataType) + val partitions = metastoreRelation.hiveQlPartitions.map { p => + val location = p.getLocation + val values = Row.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { + case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) + }) + ParquetPartition(values, location) + } + val partitionSpec = PartitionSpec(partitionSchema, partitions) + val paths = partitions.map(_.path) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions, None, Some(partitionSpec))(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation + } else { + val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, None) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions)(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation } + + result.newInstance() + } + + override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = synchronized { + val dbName = if (!caseSensitive) { + if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None + } else { + databaseName + } + val db = dbName.getOrElse(hive.sessionState.getCurrentDatabase) + + client.getAllTables(db).map(tableName => (tableName, false)) } /** @@ -212,7 +356,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val hiveSchema: JList[FieldSchema] = if (schema == null || schema.isEmpty) { crtTbl.getCols } else { - schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), "")) + schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), null)) } tbl.setFields(hiveSchema) @@ -243,9 +387,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with logInfo(s"Default to LazySimpleSerDe for table $dbName.$tblName") tbl.setSerializationLib(classOf[LazySimpleSerDe].getName()) - import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.io.Text + import org.apache.hadoop.mapred.TextInputFormat tbl.setInputFormatClass(classOf[TextInputFormat]) tbl.setOutputFormatClass(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]]) @@ -286,6 +430,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with if (crtTbl != null && crtTbl.getLineDelim() != null) { tbl.setSerdeParam(serdeConstants.LINE_DELIM, crtTbl.getLineDelim()) } + HiveShim.setTblNullFormat(crtTbl, tbl) if (crtTbl != null && crtTbl.getSerdeProps() != null) { val iter = crtTbl.getSerdeProps().entrySet().iterator() @@ -367,13 +512,87 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } + /** + * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet + * data source relations for better performance. + * + * This rule can be considered as [[HiveStrategies.ParquetConversion]] done right. + */ + object ParquetConversions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!plan.resolved) { + return plan + } + + // Collects all `MetastoreRelation`s which should be replaced + val toBeReplaced = plan.collect { + // Write path + case InsertIntoTable(relation: MetastoreRelation, _, _, _) + // Inserting into partitioned table is not supported in Parquet data source (yet). + if !relation.hiveQlTable.isPartitioned && + hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(relation) + val attributedRewrites = relation.output.zip(parquetRelation.output) + (relation, parquetRelation, attributedRewrites) + + // Write path + case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _) + // Inserting into partitioned table is not supported in Parquet data source (yet). + if !relation.hiveQlTable.isPartitioned && + hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(relation) + val attributedRewrites = relation.output.zip(parquetRelation.output) + (relation, parquetRelation, attributedRewrites) + + // Read path + case p @ PhysicalOperation(_, _, relation: MetastoreRelation) + if hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(relation) + val attributedRewrites = relation.output.zip(parquetRelation.output) + (relation, parquetRelation, attributedRewrites) + } + + val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap + val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _)) + + // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes + // attribute IDs referenced in other nodes. + plan.transformUp { + case r: MetastoreRelation if relationMap.contains(r) => + val parquetRelation = relationMap(r) + val alias = r.alias.getOrElse(r.tableName) + Subquery(alias, parquetRelation) + + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) + InsertIntoTable(parquetRelation, partition, child, overwrite) + + case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) + InsertIntoTable(parquetRelation, partition, child, overwrite) + + case other => other.transformExpressions { + case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a) + } + } + } + } + /** * Creates any tables required for query execution. * For example, because of a CREATE TABLE X AS statement. */ object CreateTables extends Rule[LogicalPlan] { import org.apache.hadoop.hive.ql.Context - import org.apache.hadoop.hive.ql.parse.{QB, ASTNode, SemanticAnalyzer} + import org.apache.hadoop.hive.ql.parse.{ASTNode, QB, SemanticAnalyzer} def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Wait until children are resolved. @@ -404,24 +623,69 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Some(sa.getQB().getTableDesc) } - execution.CreateTableAsSelect( - databaseName, - tableName, - child, - allowExisting, - desc) + // Check if the query specifies file format or storage handler. + val hasStorageSpec = desc match { + case Some(crtTbl) => + crtTbl != null && (crtTbl.getSerName != null || crtTbl.getStorageHandler != null) + case None => false + } + + if (hive.convertCTAS && !hasStorageSpec) { + // Do the conversion when spark.sql.hive.convertCTAS is true and the query + // does not specify any storage format (file format and storage handler). + if (dbName.isDefined) { + throw new AnalysisException( + "Cannot specify database name in a CTAS statement " + + "when spark.sql.hive.convertCTAS is set to true.") + } + + val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists + CreateTableUsingAsSelect( + tblName, + hive.conf.defaultDataSourceName, + temporary = false, + mode, + options = Map.empty[String, String], + child + ) + } else { + execution.CreateTableAsSelect( + databaseName, + tableName, + child, + allowExisting, + desc) + } case p: LogicalPlan if p.resolved => p case p @ CreateTableAsSelect(db, tableName, child, allowExisting, None) => val (dbName, tblName) = processDatabaseAndTableName(db, tableName) - val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) - execution.CreateTableAsSelect( - databaseName, - tableName, - child, - allowExisting, - None) + if (hive.convertCTAS) { + if (dbName.isDefined) { + throw new AnalysisException( + "Cannot specify database name in a CTAS statement " + + "when spark.sql.hive.convertCTAS is set to true.") + } + + val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists + CreateTableUsingAsSelect( + tblName, + hive.conf.defaultDataSourceName, + temporary = false, + mode, + options = Map.empty[String, String], + child + ) + } else { + val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) + execution.CreateTableAsSelect( + databaseName, + tableName, + child, + allowExisting, + None) + } } } @@ -438,7 +702,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with castChildOutput(p, table, child) } - def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { + def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) + : LogicalPlan = { val childOutputDataTypes = child.output.map(_.dataType) val tableOutputDataTypes = (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) @@ -447,7 +712,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with p } else if (childOutputDataTypes.size == tableOutputDataTypes.size && childOutputDataTypes.zip(tableOutputDataTypes) - .forall { case (left, right) => DataType.equalsIgnoreNullability(left, right) }) { + .forall { case (left, right) => left.sameType(right) }) { // If both types ignoring nullability of ArrayType, MapType, StructType are the same, // use InsertIntoHiveTable instead of InsertIntoTable. InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite) @@ -476,7 +741,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with */ override def unregisterTable(tableIdentifier: Seq[String]): Unit = ??? - override def unregisterAllTables() = {} + override def unregisterAllTables(): Unit = {} } /** @@ -491,12 +756,11 @@ private[hive] case class InsertIntoHiveTable( overwrite: Boolean) extends LogicalPlan { - override def children = child :: Nil - override def output = child.output + override def children: Seq[LogicalPlan] = child :: Nil + override def output: Seq[Attribute] = child.output - override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { - case (childAttr, tableAttr) => - DataType.equalsIgnoreNullability(childAttr.dataType, tableAttr.dataType) + override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { + case (childAttr, tableAttr) => childAttr.dataType.sameType(tableAttr.dataType) } } @@ -504,23 +768,36 @@ private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: TTable, val partitions: Seq[TPartition]) (@transient sqlContext: SQLContext) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { self: Product => + override def equals(other: scala.Any): Boolean = other match { + case relation: MetastoreRelation => + databaseName == relation.databaseName && + tableName == relation.tableName && + alias == relation.alias && + output == relation.output + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode(databaseName, tableName, alias, output) + } + // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions. // Right now, using org.apache.hadoop.hive.ql.metadata.Table and // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException // which indicates the SerDe we used is not Serializable. - @transient val hiveQlTable = new Table(table) + @transient val hiveQlTable: Table = new Table(table) - @transient val hiveQlPartitions = partitions.map { p => + @transient val hiveQlPartitions: Seq[Partition] = partitions.map { p => new Partition(hiveQlTable, p) } - @transient override lazy val statistics = Statistics( + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { val totalSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstTotalSize) val rawDataSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstRawDataSize) @@ -564,9 +841,9 @@ private[hive] case class MetastoreRelation ) implicit class SchemaAttribute(f: FieldSchema) { - def toAttribute = AttributeReference( + def toAttribute: AttributeReference = AttributeReference( f.getName, - sqlContext.ddlParser.parseType(f.getType), + HiveMetastoreTypes.toDataType(f.getType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true )(qualifiers = Seq(alias.getOrElse(tableName))) @@ -585,14 +862,15 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + + override def newInstance() = { + MetastoreRelation(databaseName, tableName, alias)(table, partitions)(sqlContext) + } } -object HiveMetastoreTypes { - protected val ddlParser = new DDLParser - def toDataType(metastoreType: String): DataType = synchronized { - ddlParser.parseType(metastoreType) - } +private[hive] object HiveMetastoreTypes { + def toDataType(metastoreType: String): DataType = DataTypeParser(metastoreType) def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ab305e1f82a55..0c2ffae9ac595 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -19,22 +19,28 @@ package org.apache.spark.sql.hive import java.sql.Date + +import scala.collection.mutable.ArrayBuffer + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils -import org.apache.spark.sql.SparkSQLParser +import org.apache.spark.sql.{AnalysisException, SparkSQLParser} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand +import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ +import org.apache.spark.util.random.RandomSampler /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -46,22 +52,6 @@ import scala.collection.JavaConversions._ */ private[hive] case object NativePlaceholder extends Command -/** - * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. - * @param table The table to be described. - * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. - * It is effective only when the table is a Hive table. - */ -case class DescribeCommand( - table: LogicalPlan, - isExtended: Boolean) extends Command { - override def output = Seq( - // Column names are based on Hive. - AttributeReference("col_name", StringType, nullable = false)(), - AttributeReference("data_type", StringType, nullable = false)(), - AttributeReference("comment", StringType, nullable = false)()) -} - /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ private[hive] object HiveQl { protected val nativeCommands = Seq( @@ -75,13 +65,13 @@ private[hive] object HiveQl { "TOK_SHOWINDEXES", "TOK_SHOWINDEXES", "TOK_SHOWPARTITIONS", - "TOK_SHOWTABLES", "TOK_SHOW_TBLPROPERTIES", "TOK_LOCKTABLE", "TOK_SHOWLOCKS", "TOK_UNLOCKTABLE", + "TOK_SHOW_ROLES", "TOK_CREATEROLE", "TOK_DROPROLE", "TOK_GRANT", @@ -89,11 +79,13 @@ private[hive] object HiveQl { "TOK_REVOKE", "TOK_SHOW_GRANT", "TOK_SHOW_ROLE_GRANT", + "TOK_SHOW_SET_ROLE", "TOK_CREATEFUNCTION", "TOK_DROPFUNCTION", "TOK_ALTERDATABASE_PROPERTIES", + "TOK_ALTERDATABASE_OWNER", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", "TOK_ALTERTABLE_ADDCOLS", @@ -115,6 +107,7 @@ private[hive] object HiveQl { "TOK_CREATEINDEX", "TOK_DROPDATABASE", "TOK_DROPINDEX", + "TOK_DROPTABLE_PROPERTIES", "TOK_MSCK", "TOK_ALTERVIEW_ADDPARTS", @@ -123,6 +116,7 @@ private[hive] object HiveQl { "TOK_ALTERVIEW_PROPERTIES", "TOK_ALTERVIEW_RENAME", "TOK_CREATEVIEW", + "TOK_DROPVIEW_PROPERTIES", "TOK_DROPVIEW", "TOK_EXPORT", @@ -135,6 +129,7 @@ private[hive] object HiveQl { // Commands that we do not need to explain. protected val noExplainCommands = Seq( "TOK_DESCTABLE", + "TOK_SHOWTABLES", "TOK_TRUNCATETABLE" // truncate table" is a NativeCommand, does not need to explain. ) ++ nativeCommands @@ -201,8 +196,8 @@ private[hive] object HiveQl { * Right now this function only checks the name, type, text and children of the node * for equality. */ - def checkEquals(other: ASTNode) { - def check(field: String, f: ASTNode => Any) = if (f(n) != f(other)) { + def checkEquals(other: ASTNode): Unit = { + def check(field: String, f: ASTNode => Any): Unit = if (f(n) != f(other)) { sys.error(s"$field does not match for trees. " + s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}") } @@ -214,17 +209,11 @@ private[hive] object HiveQl { val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]] val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]] leftChildren zip rightChildren foreach { - case (l,r) => l checkEquals r + case (l, r) => l checkEquals r } } } - class ParseException(sql: String, cause: Throwable) - extends Exception(s"Failed to parse: $sql", cause) - - class SemanticException(msg: String) - extends Exception(s"Error in semantic analysis: $msg") - /** * Returns the AST for the given SQL string. */ @@ -244,8 +233,10 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ def parseSql(sql: String): LogicalPlan = hqlParser(sql) + val errorRegEx = "line (\\d+):(\\d+) (.*)".r + /** Creates LogicalPlan for a given HiveQL string. */ - def createPlan(sql: String) = { + def createPlan(sql: String): LogicalPlan = { try { val tree = getAst(sql) if (nativeCommands contains tree.getText) { @@ -257,19 +248,28 @@ private[hive] object HiveQl { } } } catch { - case e: Exception => throw new ParseException(sql, e) - case e: NotImplementedError => sys.error( - s""" - |Unsupported language features in query: $sql - |${dumpTree(getAst(sql))} - |$e - |${e.getStackTrace.head} - """.stripMargin) + case pe: org.apache.hadoop.hive.ql.parse.ParseException => + pe.getMessage match { + case errorRegEx(line, start, message) => + throw new AnalysisException(message, Some(line.toInt), Some(start.toInt)) + case otherMessage => + throw new AnalysisException(otherMessage) + } + case e: Exception => + throw new AnalysisException(e.getMessage) + case e: NotImplementedError => + throw new AnalysisException( + s""" + |Unsupported language features in query: $sql + |${dumpTree(getAst(sql))} + |$e + |${e.getStackTrace.head} + """.stripMargin) } } /** Creates LogicalPlan for a given VIEW */ - def createPlanForView(view: Table, alias: Option[String]) = alias match { + def createPlanForView(view: Table, alias: Option[String]): Subquery = alias match { // because hive use things like `_c0` to build the expanded text // currently we cannot support view from "create view v1(c1) as ..." case None => Subquery(view.getTableName, createPlan(view.getViewExpandedText)) @@ -300,6 +300,7 @@ private[hive] object HiveQl { /** @return matches of the form (tokenName, children). */ def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { case t: ASTNode => + CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) Some((t.getText, Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) case _ => None @@ -322,7 +323,7 @@ private[hive] object HiveQl { clauses } - def getClause(clauseName: String, nodeList: Seq[Node]) = + def getClause(clauseName: String, nodeList: Seq[Node]): Node = getClauseOption(clauseName, nodeList).getOrElse(sys.error( s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}")) @@ -474,23 +475,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => - ExplainCommand(NoRelation, Seq(AttributeReference("plan", StringType, nullable = false)())) + ExplainCommand(OneRowRelation) case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.getText => val Some(crtTbl) :: _ :: extended :: Nil = getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand( nodeToPlan(crtTbl), - Seq(AttributeReference("plan", StringType,nullable = false)()), - extended != None) + extended = extended.isDefined) case Token("TOK_EXPLAIN", explainArgs) => // Ignore FORMATTED if present. val Some(query) :: _ :: extended :: Nil = getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand( nodeToPlan(query), - Seq(AttributeReference("plan", StringType, nullable = false)()), - extended != None) + extended = extended.isDefined) case Token("TOK_DESCTABLE", describeArgs) => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -509,15 +508,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. val tableIdent = extractTableIdent(nameParts.head) DescribeCommand( - UnresolvedRelation(tableIdent, None), extended.isDefined) + UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) case Token(".", dbName :: tableName :: colName :: Nil) => // It is describing a column with the format like "describe db.table column". NativePlaceholder case tableName => // It is describing a table with the format like "describe table". DescribeCommand( - UnresolvedRelation(Seq(tableName.getText), None), - extended.isDefined) + UnresolvedRelation(Seq(tableName.getText), None), isExtended = extended.isDefined) } } // All other cases. @@ -552,6 +550,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_TBLTEXTFILE", // Stored as TextFile "TOK_TBLRCFILE", // Stored as RCFile "TOK_TBLORCFILE", // Stored as ORC File + "TOK_TBLPARQUETFILE", // Stored as PARQUET "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat "TOK_STORAGEHANDLER", // Storage handler "TOK_TABLELOCATION", @@ -568,9 +567,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TRUNCATETABLE", Token("TOK_TABLE_PARTITION",table)::Nil) => NativePlaceholder - case Token("TOK_QUERY", - Token("TOK_FROM", fromClause :: Nil) :: - insertClauses) => + case Token("TOK_QUERY", queryArgs) + if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => + + val (fromClause: Option[ASTNode], insertClauses) = queryArgs match { + case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => + (Some(args.head), insertClauses) + case Token("TOK_INSERT", _) :: Nil => (None, queryArgs) + } // Return one query for each insert clause. val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => @@ -611,8 +615,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_LATERAL_VIEW"), singleInsert) } - - val relations = nodeToRelation(fromClause) + + val relations = fromClause match { + case Some(f) => nodeToRelation(f) + case None => OneRowRelation + } + val withWhere = whereClause.map { whereNode => val Seq(whereExpr) = whereNode.getChildren.toSeq Filter(nodeToExpr(whereExpr), relations) @@ -843,7 +851,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TABLESPLITSAMPLE", Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) => - Sample(fraction.toDouble, withReplacement = false, (math.random * 1000).toInt, relation) + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. + require( + fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) + && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), + s"Sampling fraction ($fraction) must be on interval [0, 100]") + Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, + relation) case Token("TOK_TABLEBUCKETSAMPLE", Token(numerator, Nil) :: Token(denominator, Nil) :: Nil) => @@ -968,14 +984,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { - case Token("TOK_SELEXPR", - e :: Nil) => + case Token("TOK_SELEXPR", e :: Nil) => Some(nodeToExpr(e)) - case Token("TOK_SELEXPR", - e :: Token(alias, Nil) :: Nil) => + case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) + case Token("TOK_SELEXPR", e :: aliasChildren) => + var aliasNames = ArrayBuffer[String]() + aliasChildren.foreach { _ match { + case Token(name, Nil) => aliasNames += cleanIdentifier(name) + case _ => + } + } + Some(MultiAlias(nodeToExpr(e), aliasNames)) + /* Hints are ignored */ case Token("TOK_HINTLIST", _) => None @@ -1034,7 +1057,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C nodeToExpr(qualifier) match { case UnresolvedAttribute(qualifierName) => UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)) - case other => GetField(other, attr) + case other => UnresolvedGetField(other, attr) } /* Stars (*) */ @@ -1093,6 +1116,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Cast(nodeToExpr(arg), DateType) /* Arithmetic */ + case Token("+", child :: Nil) => nodeToExpr(child) case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) @@ -1230,6 +1254,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) + case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => + Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : @@ -1268,7 +1295,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0) : StringBuilder = { node match { - case a: ASTNode => builder.append((" " * indent) + a.getText + "\n") + case a: ASTNode => builder.append( + (" " * indent) + a.getText + " " + + a.getLine + ", " + + a.getTokenStartIndex + "," + + a.getTokenStopIndex + ", " + + a.getCharPositionInLine + "\n") case other => sys.error(s"Non ASTNode encountered: $other") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d89111094b9ff..5f7e897295117 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -29,11 +29,12 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{CreateTableUsingAsLogicalPlan, CreateTableUsingAsSelect, CreateTableUsing} +import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, CreateTableUsing} import org.apache.spark.sql.types.StringType @@ -57,9 +58,9 @@ private[hive] trait HiveStrategies { @Experimental object ParquetConversion extends Strategy { implicit class LogicalPlanHacks(s: DataFrame) { - def lowerCase = DataFrame(s.sqlContext, s.logicalPlan) + def lowerCase: DataFrame = DataFrame(s.sqlContext, s.logicalPlan) - def addPartitioningAttributes(attrs: Seq[Attribute]) = { + def addPartitioningAttributes(attrs: Seq[Attribute]): DataFrame = { // Don't add the partitioning key if its already present in the data. if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { s @@ -74,7 +75,7 @@ private[hive] trait HiveStrategies { } implicit class PhysicalPlanHacks(originalPlan: SparkPlan) { - def fakeOutput(newOutput: Seq[Attribute]) = + def fakeOutput(newOutput: Seq[Attribute]): OutputFaker = OutputFaker( originalPlan.output.map(a => newOutput.find(a.name.toLowerCase == _.name.toLowerCase) @@ -86,7 +87,8 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) if relation.tableDesc.getSerdeClassName.contains("Parquet") && - hiveContext.convertMetastoreParquet => + hiveContext.convertMetastoreParquet && + !hiveContext.conf.parquetUseDataSourceApi => // Filter out all predicates that only deal with partition keys val partitionsKeys = AttributeSet(relation.partitionKeys) @@ -135,15 +137,22 @@ private[hive] trait HiveStrategies { pruningCondition(inputData) } - hiveContext - .parquetFile(partitions.map(_.getLocation).mkString(",")) - .addPartitioningAttributes(relation.partitionKeys) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil + val partitionLocations = partitions.map(_.getLocation) + + if (partitionLocations.isEmpty) { + PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil + } else { + hiveContext + .parquetFile(partitionLocations: _*) + .addPartitioningAttributes(relation.partitionKeys) + .lowerCase + .where(unresolvedOtherPredicates) + .select(unresolvedProjection: _*) + .queryExecution + .executedPlan + .fakeOutput(projectList.map(_.toAttribute)) :: Nil + } + } else { hiveContext .parquetFile(relation.hiveQlTable.getDataLocation.toString) @@ -212,20 +221,15 @@ private[hive] trait HiveStrategies { object HiveDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, opts, allowExisting) => + case CreateTableUsing( + tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => ExecutedCommand( CreateMetastoreDataSource( - tableName, userSpecifiedSchema, provider, opts, allowExisting)) :: Nil + tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil - case CreateTableUsingAsSelect(tableName, provider, false, opts, allowExisting, query) => - val logicalPlan = hiveContext.parseSql(query) + case CreateTableUsingAsSelect(tableName, provider, false, mode, opts, query) => val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, logicalPlan) - ExecutedCommand(cmd) :: Nil - - case CreateTableUsingAsLogicalPlan(tableName, provider, false, opts, allowExisting, query) => - val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, query) + CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, query) ExecutedCommand(cmd) :: Nil case _ => Nil @@ -240,8 +244,11 @@ private[hive] trait HiveStrategies { case t: MetastoreRelation => ExecutedCommand( DescribeHiveTableCommand(t, describe.output, describe.isExtended)) :: Nil + case o: LogicalPlan => - ExecutedCommand(RunnableDescribeCommand(planLater(o), describe.output)) :: Nil + val resultPlan = context.executePlan(o).executedPlan + ExecutedCommand(RunnableDescribeCommand( + resultPlan, describe.output, describe.isExtended)) :: Nil } case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index c368715f7c6f5..af309c0c6ce2c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -34,6 +34,7 @@ import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DateUtils /** * A trait for subclasses that handle table scans. @@ -174,7 +175,7 @@ class HadoopTableReader( relation.partitionKeys.contains(attr) } - def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow) = { + def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow): Unit = { partitionKeyAttrs.foreach { case (attr, ordinal) => val partOrdinal = relation.partitionKeys.indexOf(attr) row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) @@ -247,7 +248,7 @@ private[hive] object HadoopTableReader extends HiveInspectors { * instantiate a HadoopRDD. */ def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) { - FileInputFormat.setInputPaths(jobConf, path) + FileInputFormat.setInputPaths(jobConf, Seq[Path](new Path(path)): _*) if (tableDesc != null) { PlanUtils.configureInputJobPropertiesForStorageHandler(tableDesc) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) @@ -306,7 +307,7 @@ private[hive] object HadoopTableReader extends HiveInspectors { row.update(ordinal, oi.getPrimitiveJavaObject(value).clone()) case oi: DateObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, oi.getPrimitiveJavaObject(value)) + row.update(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) case oi: BinaryObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, oi.getPrimitiveJavaObject(value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index a547babcebfff..fade9e5852eaa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.MetastoreRelation /** - * :: Experimental :: * Create table and insert the query result into it. * @param database the database name of the new relation * @param tableName the table name of the new relation @@ -38,7 +37,7 @@ import org.apache.spark.sql.hive.MetastoreRelation * @param desc the CreateTableDesc, which may contains serde, storage handler etc. */ -@Experimental +private[hive] case class CreateTableAsSelect( database: String, tableName: String, @@ -46,7 +45,7 @@ case class CreateTableAsSelect( allowExisting: Boolean, desc: Option[CreateTableDesc]) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { // Create Hive Table diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index bfacc51ef57ab..6fce69b58b85e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -30,16 +30,14 @@ import org.apache.spark.sql.SQLContext /** * Implementation for "describe [extended] table". - * - * :: DeveloperApi :: */ -@DeveloperApi +private[hive] case class DescribeHiveTableCommand( table: MetastoreRelation, override val output: Seq[Attribute], isExtended: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { // Trying to mimic the format of Hive's output. But not exactly the same. var results: Seq[(String, String, String)] = Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 781a2e9164c82..60a9bb630d0d9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -17,22 +17,18 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StringType -/** - * :: DeveloperApi :: - */ -@DeveloperApi +private[hive] case class HiveNativeCommand(sql: String) extends RunnableCommand { - override def output = + override def output: Seq[AttributeReference] = Seq(AttributeReference("result", StringType, nullable = false)()) - override def run(sqlContext: SQLContext) = + override def run(sqlContext: SQLContext): Seq[Row] = sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index b56175fe76376..0a5f19eee7105 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -26,21 +26,20 @@ import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.{BooleanType, DataType} /** - * :: DeveloperApi :: * The Hive table scan operator. Column and partition pruning are both handled. * * @param requestedAttributes Attributes to be fetched from the Hive table. * @param relation The Hive table be be scanned. * @param partitionPruningPred An optional partition pruning predicate for partitioned table. */ -@DeveloperApi +private[hive] case class HiveTableScan( requestedAttributes: Seq[Attribute], relation: MetastoreRelation, @@ -130,11 +129,11 @@ case class HiveTableScan( } } - override def execute() = if (!relation.hiveQlTable.isPartitioned) { + override def execute(): RDD[Row] = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) } - override def output = attributes + override def output: Seq[Attribute] = attributes } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 91af35f0965c0..6c96747439683 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -32,19 +32,15 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.Object import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.{SerializableWritable, SparkException, TaskContext} -/** - * :: DeveloperApi :: - */ -@DeveloperApi +private[hive] case class InsertIntoHiveTable( table: MetastoreRelation, partition: Map[String, Option[String]], @@ -54,7 +50,7 @@ case class InsertIntoHiveTable( @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) - @transient private lazy val db = Hive.get(sc.hiveconf) + @transient private lazy val catalog = sc.catalog private def newSerializer(tableDesc: TableDesc): Serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] @@ -62,7 +58,7 @@ case class InsertIntoHiveTable( serializer } - def output = child.output + def output: Seq[Attribute] = child.output def saveAsHiveFile( rdd: RDD[Row], @@ -76,7 +72,6 @@ case class InsertIntoHiveTable( val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName assert(outputFileFormatClassName != null, "Output format class not set") conf.value.set("mapred.output.format.class", outputFileFormatClassName) - conf.value.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath( conf.value, @@ -204,38 +199,45 @@ case class InsertIntoHiveTable( orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) - db.validatePartitionNameCharacters(partVals) + catalog.synchronized { + catalog.client.validatePartitionNameCharacters(partVals) + } // inheritTableSpecs is set to true. It should be set to false for a IMPORT query // which is currently considered as a Hive native command. val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false if (numDynamicPartitions > 0) { - db.loadDynamicPartitions( - outputPath, - qualifiedTableName, - orderedPartitionSpec, - overwrite, - numDynamicPartitions, - holdDDLTime, - isSkewedStoreAsSubdir - ) + catalog.synchronized { + catalog.client.loadDynamicPartitions( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir) + } } else { - db.loadPartition( + catalog.synchronized { + catalog.client.loadPartition( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } + } + } else { + catalog.synchronized { + catalog.client.loadTable( outputPath, qualifiedTableName, - orderedPartitionSpec, overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + holdDDLTime) } - } else { - db.loadTable( - outputPath, - qualifiedTableName, - overwrite, - holdDDLTime) } // Invalidate the cache. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index c54fbb6e24690..8efed7f0299bf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -21,15 +21,13 @@ import java.io.{BufferedReader, InputStreamReader} import java.io.{DataInputStream, DataOutputStream, EOFException} import java.util.Properties +import scala.collection.JavaConversions._ + import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe -import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.Deserializer import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ @@ -38,19 +36,14 @@ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.util.Utils - -/* Implicit conversions */ -import scala.collection.JavaConversions._ - /** - * :: DeveloperApi :: * Transforms the input by forking and running the specified script. * * @param input the set of expression that should be passed to the script. * @param script the command that should be executed. * @param output the attributes that are produced by the script. */ -@DeveloperApi +private[hive] case class ScriptTransformation( input: Seq[Expression], script: String, @@ -59,9 +52,9 @@ case class ScriptTransformation( ioschema: HiveScriptIOSchema)(@transient sc: HiveContext) extends UnaryNode { - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs: Seq[HiveContext] = sc :: Nil - def execute() = { + def execute(): RDD[Row] = { child.execute().mapPartitions { iter => val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd) @@ -175,6 +168,7 @@ case class ScriptTransformation( /** * The wrapper class of Hive input and output schema properties */ +private[hive] case class HiveScriptIOSchema ( inputRowFormat: Seq[(String, String)], outputRowFormat: Seq[(String, String)], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 95dcaccefdc54..99dc58646ddd6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.sources.ResolvedDataSource -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand @@ -27,43 +29,42 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.types.StructType /** - * :: DeveloperApi :: * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. * * Right now, it only supports Hive tables and it only updates the size of a Hive table * in the Hive metastore. */ -@DeveloperApi +private[hive] case class AnalyzeTable(tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.asInstanceOf[HiveContext].analyze(tableName) Seq.empty[Row] } } /** - * :: DeveloperApi :: * Drops a table from the metastore and removes it if it is cached. */ -@DeveloperApi +private[hive] case class DropTable( tableName: String, ifExists: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] val ifExistsClause = if (ifExists) "IF EXISTS " else "" try { hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) } catch { - // This table's metadata is not in + // This table's metadata is not in Hive metastore (e.g. the table does not exist). case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => - // Other exceptions can be caused by users providing wrong parameters in OPTIONS + case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => + // Other Throwables can be caused by users providing wrong parameters in OPTIONS // (e.g. invalid paths). We catch it and log a warning message. - // Users should be able to drop such kinds of tables regardless if there is an exception. - case e: Exception => log.warn(s"${e.getMessage}") + // Users should be able to drop such kinds of tables regardless if there is an error. + case e: Throwable => log.warn(s"${e.getMessage}", e) } hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") @@ -72,13 +73,10 @@ case class DropTable( } } -/** - * :: DeveloperApi :: - */ -@DeveloperApi +private[hive] case class AddJar(path: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD JAR $path") hiveContext.sparkContext.addJar(path) @@ -86,13 +84,10 @@ case class AddJar(path: String) extends RunnableCommand { } } -/** - * :: DeveloperApi :: - */ -@DeveloperApi +private[hive] case class AddFile(path: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD FILE $path") hiveContext.sparkContext.addFile(path) @@ -100,12 +95,14 @@ case class AddFile(path: String) extends RunnableCommand { } } +private[hive] case class CreateMetastoreDataSource( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String], - allowExisting: Boolean) extends RunnableCommand { + allowExisting: Boolean, + managedIfNoPath: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] @@ -114,13 +111,13 @@ case class CreateMetastoreDataSource( if (allowExisting) { return Seq.empty[Row] } else { - sys.error(s"Table $tableName already exists.") + throw new AnalysisException(s"Table $tableName already exists.") } } var isExternal = true val optionsWithPath = - if (!options.contains("path")) { + if (!options.contains("path") && managedIfNoPath) { isExternal = false options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) } else { @@ -138,25 +135,17 @@ case class CreateMetastoreDataSource( } } +private[hive] case class CreateMetastoreDataSourceAsSelect( tableName: String, provider: String, + mode: SaveMode, options: Map[String, String], - allowExisting: Boolean, query: LogicalPlan) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] - - if (hiveContext.catalog.tableExists(tableName :: Nil)) { - if (allowExisting) { - return Seq.empty[Row] - } else { - sys.error(s"Table $tableName already exists.") - } - } - - val df = DataFrame(hiveContext, query) + var createMetastoreTable = false var isExternal = true val optionsWithPath = if (!options.contains("path")) { @@ -166,16 +155,82 @@ case class CreateMetastoreDataSourceAsSelect( options } - // Create the relation based on the data of df. - ResolvedDataSource(sqlContext, provider, optionsWithPath, df) + var existingSchema = None: Option[StructType] + if (sqlContext.catalog.tableExists(Seq(tableName))) { + // Check if we need to throw an exception or just return. + mode match { + case SaveMode.ErrorIfExists => + throw new AnalysisException(s"Table $tableName already exists. " + + s"If you are using saveAsTable, you can set SaveMode to SaveMode.Append to " + + s"insert data into the table or set SaveMode to SaveMode.Overwrite to overwrite" + + s"the existing data. " + + s"Or, if you are using SQL CREATE TABLE, you need to drop $tableName first.") + case SaveMode.Ignore => + // Since the table already exists and the save mode is Ignore, we will just return. + return Seq.empty[Row] + case SaveMode.Append => + // Check if the specified data source match the data source of the existing table. + val resolved = + ResolvedDataSource(sqlContext, Some(query.schema), provider, optionsWithPath) + val createdRelation = LogicalRelation(resolved.relation) + EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match { + case l @ LogicalRelation(i: InsertableRelation) => + if (i != createdRelation.relation) { + val errorDescription = + s"Cannot append to table $tableName because the resolved relation does not " + + s"match the existing relation of $tableName. " + + s"You can use insertInto($tableName, false) to append this DataFrame to the " + + s"table $tableName and using its data source and options." + val errorMessage = + s""" + |$errorDescription + |== Relations == + |${sideBySide( + s"== Expected Relation ==" :: + l.toString :: Nil, + s"== Actual Relation ==" :: + createdRelation.toString :: Nil).mkString("\n")} + """.stripMargin + throw new AnalysisException(errorMessage) + } + existingSchema = Some(l.schema) + case o => + throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") + } + case SaveMode.Overwrite => + hiveContext.sql(s"DROP TABLE IF EXISTS $tableName") + // Need to create the table again. + createMetastoreTable = true + } + } else { + // The table does not exist. We need to create it in metastore. + createMetastoreTable = true + } - hiveContext.catalog.createDataSourceTable( - tableName, - None, - provider, - optionsWithPath, - isExternal) + val data = DataFrame(hiveContext, query) + val df = existingSchema match { + // If we are inserting into an existing table, just use the existing schema. + case Some(schema) => sqlContext.createDataFrame(data.queryExecution.toRdd, schema) + case None => data + } + + // Create the relation based on the data of df. + val resolved = ResolvedDataSource(sqlContext, provider, mode, optionsWithPath, df) + + if (createMetastoreTable) { + // We will use the schema of resolved.relation as the schema of the table (instead of + // the schema of df). It is important since the nullability may be changed by the relation + // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). + hiveContext.catalog.createDataSourceTable( + tableName, + Some(resolved.relation.schema), + provider, + optionsWithPath, + isExternal) + } + // Refresh the cache of the table in the catalog. + hiveContext.refreshTable(tableName) Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 76d2140372197..7df964fa9cead 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -33,8 +33,11 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Generate, Project, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils.getContextOrSparkClassLoader +import org.apache.spark.sql.catalyst.analysis.MultiAlias +import org.apache.spark.sql.catalyst.errors.TreeNodeException /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -42,7 +45,7 @@ import scala.collection.JavaConversions._ private[hive] abstract class HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { - def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) + def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is @@ -75,7 +78,7 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre type EvaluatedType = Any type UDFType = UDF - def nullable = true + override def nullable: Boolean = true @transient lazy val function = funcWrapper.createFunction[UDFType]() @@ -93,7 +96,7 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre udfType != null && udfType.deterministic() } - override def foldable = isUDFDeterministic && children.forall(_.foldable) + override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable) // Create parameter converters @transient @@ -107,7 +110,7 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre method.getGenericReturnType(), ObjectInspectorOptions.JAVA) @transient - protected lazy val cached = new Array[AnyRef](children.length) + protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) // TODO: Finish input output types. override def eval(input: Row): Any = { @@ -117,17 +120,19 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre returnInspector) } - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } } // Adapter from Catalyst ExpressionResult to Hive DeferredObject private[hive] class DeferredObjectAdapter(oi: ObjectInspector) extends DeferredObject with HiveInspectors { private var func: () => Any = _ - def set(func: () => Any) { + def set(func: () => Any): Unit = { this.func = func } - override def prepare(i: Int) = {} + override def prepare(i: Int): Unit = {} override def get(): AnyRef = wrap(func(), oi) } @@ -136,7 +141,7 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr type UDFType = GenericUDF type EvaluatedType = Any - def nullable = true + override def nullable: Boolean = true @transient lazy val function = funcWrapper.createFunction[UDFType]() @@ -155,7 +160,7 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr (udfType != null && udfType.deterministic()) } - override def foldable = + override def foldable: Boolean = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] @transient @@ -179,7 +184,9 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr unwrap(function.evaluate(deferedObjects), returnInspector) } - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } } private[hive] case class HiveGenericUdaf( @@ -206,9 +213,11 @@ private[hive] case class HiveGenericUdaf( def nullable: Boolean = true - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } - def newInstance() = new HiveUdafFunction(funcWrapper, children, this) + def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this) } /** It is used as a wrapper for the hive functions which uses UDAF interface */ @@ -237,10 +246,11 @@ private[hive] case class HiveUdaf( def nullable: Boolean = true - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } - def newInstance() = - new HiveUdafFunction(funcWrapper, children, this, true) + def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true) } /** @@ -311,14 +321,30 @@ private[hive] case class HiveGenericUdtf( collected += unwrap(input, outputInspector).asInstanceOf[Row] } - def collectRows() = { + def collectRows(): Seq[Row] = { val toCollect = collected collected = new ArrayBuffer[Row] toCollect } } - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } +} + +/** + * Resolve Udtfs Alias. + */ +private[spark] object ResolveUdtfsAlias extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p @ Project(projectList, _) + if projectList.exists(_.isInstanceOf[MultiAlias]) && projectList.size != 1 => + throw new TreeNodeException(p, "only single Generator supported for SELECT clause") + + case Project(Seq(MultiAlias(udtf @ HiveGenericUdtf(_, _, _), names)), child) => + Generate(udtf.copy(aliasNames = names), join = false, outer = false, None, child) + } } private[hive] case class HiveUdafFunction( @@ -347,9 +373,8 @@ private[hive] case class HiveUdafFunction( private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) - // Cast required to avoid type inference selecting a deprecated Hive API. private val buffer = - function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] + function.getNewAggregationBuffer override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index aae175e426ade..8398da268174d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive -import java.io.IOException import java.text.NumberFormat import java.util.Date @@ -30,6 +29,7 @@ import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ +import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row @@ -117,19 +117,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - if (committer.needsTaskCommit(taskContext)) { - try { - committer.commitTask(taskContext) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - committer.abortTask(taskContext) - throw e - } - } else { - logInfo("No need to commit output of task: " + taID.value) - } + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { @@ -212,11 +200,16 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( .zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => val string = if (rawVal == null) null else String.valueOf(rawVal) - s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}" - } - .mkString - - def newWriter = { + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string) + } + s"/$col=$colString" + }.mkString + + def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( fileSinkConf.getDirName + dynamicPartPath, fileSinkConf.getTableInfo, @@ -240,6 +233,6 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( Reporter.NULL) } - writers.getOrElseUpdate(dynamicPartPath, newWriter) + writers.getOrElseUpdate(dynamicPartPath, newWriter()) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala index a6c8ed4f7e866..db074361ef03c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala @@ -17,4 +17,14 @@ package org.apache.spark.sql +/** + * Support for running Spark SQL queries using functionality from Apache Hive (does not require an + * existing Hive installation). Supported Hive features include: + * - Using HiveQL to express queries. + * - Reading metadata from the Hive Metastore using HiveSerDes. + * - Hive UDFs, UDAs, UDTs + * + * Users that would like access to this functionality should create a + * [[hive.HiveContext HiveContext]] instead of a [[SQLContext]]. + */ package object hive diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala deleted file mode 100644 index 2a16c9d1a27c9..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.parquet - -import java.util.Properties - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category -import org.apache.hadoop.hive.serde2.{SerDeStats, SerDe} -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector -import org.apache.hadoop.io.Writable - -/** - * A placeholder that allows Spark SQL users to create metastore tables that are stored as - * parquet files. It is only intended to pass the checks that the serde is valid and exists - * when a CREATE TABLE is run. The actual work of decoding will be done by ParquetTableScan - * when "spark.sql.hive.convertMetastoreParquet" is set to true. - */ -@deprecated("No code should depend on FakeParquetHiveSerDe as it is only intended as a " + - "placeholder in the Hive MetaStore", "1.2.0") -class FakeParquetSerDe extends SerDe { - override def getObjectInspector: ObjectInspector = new ObjectInspector { - override def getCategory: Category = Category.PRIMITIVE - - override def getTypeName: String = "string" - } - - override def deserialize(p1: Writable): AnyRef = throwError - - override def initialize(p1: Configuration, p2: Properties): Unit = {} - - override def getSerializedClass: Class[_ <: Writable] = throwError - - override def getSerDeStats: SerDeStats = throwError - - override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = throwError - - private def throwError = - sys.error( - "spark.sql.hive.convertMetastoreParquet must be set to true to use FakeParquetSerDe") -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala similarity index 97% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7c1d1133c3425..3380ef0475cdb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -20,9 +20,6 @@ package org.apache.spark.sql.hive.test import java.io.File import java.util.{Set => JavaSet} -import scala.collection.mutable -import scala.language.implicitConversions - import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} import org.apache.hadoop.hive.ql.metadata.Table @@ -30,16 +27,18 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.RegexSerDe import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.avro.AvroSerDe - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.util.Utils +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkContext} + +import scala.collection.mutable +import scala.language.implicitConversions /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -154,8 +153,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution(HiveQl.parseSql(hql)) { - def hiveExec() = runSqlHive(hql) - override def toString = hql + "\n" + super.toString + def hiveExec(): Seq[String] = runSqlHive(hql) + override def toString: String = hql + "\n" + super.toString } /** @@ -185,7 +184,9 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { case class TestTable(name: String, commands: (()=>Unit)*) protected[hive] implicit class SqlCmd(sql: String) { - def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit + def cmd: () => Unit = { + () => new HiveQLQueryExecution(sql).stringResult(): Unit + } } /** @@ -193,10 +194,14 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { * demand when a query are run against it. */ lazy val testTables = new mutable.HashMap[String, TestTable]() - def registerTestTable(testTable: TestTable) = testTables += (testTable.name -> testTable) + + def registerTestTable(testTable: TestTable): Unit = { + testTables += (testTable.name -> testTable) + } // The test tables that are defined in the Hive QTestUtil. // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java + // https://github.com/apache/hive/blob/branch-0.13/data/scripts/q_test_init.sql val hiveQTestUtilTables = Seq( TestTable("src", "CREATE TABLE src (key INT, value STRING)".cmd, @@ -224,11 +229,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } }), TestTable("src_thrift", () => { - import org.apache.thrift.protocol.TBinaryProtocol - import org.apache.hadoop.hive.serde2.thrift.test.Complex import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer - import org.apache.hadoop.mapred.SequenceFileInputFormat - import org.apache.hadoop.mapred.SequenceFileOutputFormat + import org.apache.hadoop.hive.serde2.thrift.test.Complex + import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} + import org.apache.thrift.protocol.TBinaryProtocol val srcThrift = new Table("default", "src_thrift") srcThrift.setFields(Nil) diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java new file mode 100644 index 0000000000000..53ddecf57958b --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.SaveMode; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.QueryTest$; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; + +public class JavaMetastoreDataSourcesSuite { + private transient JavaSparkContext sc; + private transient HiveContext sqlContext; + + String originalDefaultSource; + File path; + Path hiveManagedPath; + FileSystem fs; + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + sqlContext = TestHive$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); + + originalDefaultSource = sqlContext.conf().defaultDataSourceName(); + path = + Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); + if (path.exists()) { + path.delete(); + } + hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath("javaSavedTable")); + fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); + if (fs.exists(hiveManagedPath)){ + fs.delete(hiveManagedPath, true); + } + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); + } + JavaRDD rdd = sc.parallelize(jsonObjects); + df = sqlContext.jsonRDD(rdd); + df.registerTempTable("jsonTable"); + } + + @After + public void tearDown() throws IOException { + // Clean up tables. + sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); + sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + } + + @Test + public void saveExternalTableAndQueryIt() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + + DataFrame loadedDF = + sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options); + + checkAnswer(loadedDF, df.collectAsList()); + checkAnswer( + sqlContext.sql("SELECT * FROM externalTable"), + df.collectAsList()); + } + + @Test + public void saveExternalTableWithSchemaAndQueryIt() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + + List fields = new ArrayList(); + fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame loadedDF = + sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", schema, options); + + checkAnswer( + loadedDF, + sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList()); + checkAnswer( + sqlContext.sql("SELECT * FROM externalTable"), + sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList()); + } + + @Test + public void saveTableAndQueryIt() { + Map options = new HashMap(); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + } +} diff --git a/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d b/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d new file mode 100644 index 0000000000000..98da82fa89386 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d @@ -0,0 +1 @@ +1970-01-01 1970-01-01 1969-12-31 16:00:00 1969-12-31 16:00:00 1970-01-01 00:00:00 diff --git a/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e b/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 b/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-1-23edf29bf7376c70d5ecf12720f4b1eb similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb rename to sql/hive/src/test/resources/golden/date_1-1-23edf29bf7376c70d5ecf12720f4b1eb diff --git a/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-10-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-10-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-11-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-11-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 b/sql/hive/src/test/resources/golden/date_1-12-480c5f024a28232b7857be327c992509 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 rename to sql/hive/src/test/resources/golden/date_1-12-480c5f024a28232b7857be327c992509 diff --git a/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 b/sql/hive/src/test/resources/golden/date_1-13-4c0ed7fcb75770d8790575b586bf14f4 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 rename to sql/hive/src/test/resources/golden/date_1-13-4c0ed7fcb75770d8790575b586bf14f4 diff --git a/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea b/sql/hive/src/test/resources/golden/date_1-14-44fc74c1993062c0a9522199ff27fea similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea rename to sql/hive/src/test/resources/golden/date_1-14-44fc74c1993062c0a9522199ff27fea diff --git a/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b b/sql/hive/src/test/resources/golden/date_1-15-4855a66124b16d1d0d003235995ac06b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b rename to sql/hive/src/test/resources/golden/date_1-15-4855a66124b16d1d0d003235995ac06b diff --git a/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b b/sql/hive/src/test/resources/golden/date_1-16-8bc190dba0f641840b5e1e198a14c55b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b rename to sql/hive/src/test/resources/golden/date_1-16-8bc190dba0f641840b5e1e198a14c55b diff --git a/sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b b/sql/hive/src/test/resources/golden/date_1-17-23edf29bf7376c70d5ecf12720f4b1eb similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b rename to sql/hive/src/test/resources/golden/date_1-17-23edf29bf7376c70d5ecf12720f4b1eb diff --git a/sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-2-4ebe3571c13a8b0c03096fbd972b7f1b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb rename to sql/hive/src/test/resources/golden/date_1-2-4ebe3571c13a8b0c03096fbd972b7f1b diff --git a/sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a b/sql/hive/src/test/resources/golden/date_1-3-26b5c291400dfde455b3c1b878b71d0 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a rename to sql/hive/src/test/resources/golden/date_1-3-26b5c291400dfde455b3c1b878b71d0 diff --git a/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-4-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-4-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-5-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-5-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d b/sql/hive/src/test/resources/golden/date_1-6-559d01fb0b42c42f0c4927fa0f9deac4 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d rename to sql/hive/src/test/resources/golden/date_1-6-559d01fb0b42c42f0c4927fa0f9deac4 diff --git a/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-7-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-7-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-8-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-8-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d b/sql/hive/src/test/resources/golden/date_1-9-8306558e0eabe936ac33dabaaa17fea4 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d rename to sql/hive/src/test/resources/golden/date_1-9-8306558e0eabe936ac33dabaaa17fea4 diff --git a/sql/hive/src/test/resources/golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7 b/sql/hive/src/test/resources/golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75 b/sql/hive/src/test/resources/golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b b/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b new file mode 100644 index 0000000000000..518a70918b2c7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b @@ -0,0 +1 @@ +name string diff --git a/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce b/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce new file mode 100644 index 0000000000000..33398360345d7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce @@ -0,0 +1 @@ +邵铮 diff --git a/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 b/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/merge4-0-b12e5c70d6d29757471b900b6160fa8a b/sql/hive/src/test/resources/golden/merge4-0-b12e5c70d6d29757471b900b6160fa8a new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-0-b12e5c70d6d29757471b900b6160fa8a @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-1-593999fae618b6b38322bc9ae4e0c027 b/sql/hive/src/test/resources/golden/merge4-1-593999fae618b6b38322bc9ae4e0c027 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-1-593999fae618b6b38322bc9ae4e0c027 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-10-692a197bd688b48f762e72978f54aa32 b/sql/hive/src/test/resources/golden/merge4-10-692a197bd688b48f762e72978f54aa32 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/merge4-11-f407e661307b23a5d52a08a3e7af19b b/sql/hive/src/test/resources/golden/merge4-11-f407e661307b23a5d52a08a3e7af19b new file mode 100644 index 0000000000000..5d2cddc42f272 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-11-f407e661307b23a5d52a08a3e7af19b @@ -0,0 +1,1500 @@ +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +2 val_2 2010-08-15 11 +2 val_2 2010-08-15 11 +2 val_2 2010-08-15 12 +4 val_4 2010-08-15 11 +4 val_4 2010-08-15 11 +4 val_4 2010-08-15 12 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +8 val_8 2010-08-15 11 +8 val_8 2010-08-15 11 +8 val_8 2010-08-15 12 +9 val_9 2010-08-15 11 +9 val_9 2010-08-15 11 +9 val_9 2010-08-15 12 +10 val_10 2010-08-15 11 +10 val_10 2010-08-15 11 +10 val_10 2010-08-15 12 +11 val_11 2010-08-15 11 +11 val_11 2010-08-15 11 +11 val_11 2010-08-15 12 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 12 +12 val_12 2010-08-15 12 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 12 +15 val_15 2010-08-15 12 +17 val_17 2010-08-15 11 +17 val_17 2010-08-15 11 +17 val_17 2010-08-15 12 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 12 +18 val_18 2010-08-15 12 +19 val_19 2010-08-15 11 +19 val_19 2010-08-15 11 +19 val_19 2010-08-15 12 +20 val_20 2010-08-15 11 +20 val_20 2010-08-15 11 +20 val_20 2010-08-15 12 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 12 +24 val_24 2010-08-15 12 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 12 +26 val_26 2010-08-15 12 +27 val_27 2010-08-15 11 +27 val_27 2010-08-15 11 +27 val_27 2010-08-15 12 +28 val_28 2010-08-15 11 +28 val_28 2010-08-15 11 +28 val_28 2010-08-15 12 +30 val_30 2010-08-15 11 +30 val_30 2010-08-15 11 +30 val_30 2010-08-15 12 +33 val_33 2010-08-15 11 +33 val_33 2010-08-15 11 +33 val_33 2010-08-15 12 +34 val_34 2010-08-15 11 +34 val_34 2010-08-15 11 +34 val_34 2010-08-15 12 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 12 +37 val_37 2010-08-15 12 +41 val_41 2010-08-15 11 +41 val_41 2010-08-15 11 +41 val_41 2010-08-15 12 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 12 +42 val_42 2010-08-15 12 +43 val_43 2010-08-15 11 +43 val_43 2010-08-15 11 +43 val_43 2010-08-15 12 +44 val_44 2010-08-15 11 +44 val_44 2010-08-15 11 +44 val_44 2010-08-15 12 +47 val_47 2010-08-15 11 +47 val_47 2010-08-15 11 +47 val_47 2010-08-15 12 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 12 +51 val_51 2010-08-15 12 +53 val_53 2010-08-15 11 +53 val_53 2010-08-15 11 +53 val_53 2010-08-15 12 +54 val_54 2010-08-15 11 +54 val_54 2010-08-15 11 +54 val_54 2010-08-15 12 +57 val_57 2010-08-15 11 +57 val_57 2010-08-15 11 +57 val_57 2010-08-15 12 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 12 +58 val_58 2010-08-15 12 +64 val_64 2010-08-15 11 +64 val_64 2010-08-15 11 +64 val_64 2010-08-15 12 +65 val_65 2010-08-15 11 +65 val_65 2010-08-15 11 +65 val_65 2010-08-15 12 +66 val_66 2010-08-15 11 +66 val_66 2010-08-15 11 +66 val_66 2010-08-15 12 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 12 +67 val_67 2010-08-15 12 +69 val_69 2010-08-15 11 +69 val_69 2010-08-15 11 +69 val_69 2010-08-15 12 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 12 +72 val_72 2010-08-15 12 +74 val_74 2010-08-15 11 +74 val_74 2010-08-15 11 +74 val_74 2010-08-15 12 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 12 +76 val_76 2010-08-15 12 +77 val_77 2010-08-15 11 +77 val_77 2010-08-15 11 +77 val_77 2010-08-15 12 +78 val_78 2010-08-15 11 +78 val_78 2010-08-15 11 +78 val_78 2010-08-15 12 +80 val_80 2010-08-15 11 +80 val_80 2010-08-15 11 +80 val_80 2010-08-15 12 +82 val_82 2010-08-15 11 +82 val_82 2010-08-15 11 +82 val_82 2010-08-15 12 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 12 +83 val_83 2010-08-15 12 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 12 +84 val_84 2010-08-15 12 +85 val_85 2010-08-15 11 +85 val_85 2010-08-15 11 +85 val_85 2010-08-15 12 +86 val_86 2010-08-15 11 +86 val_86 2010-08-15 11 +86 val_86 2010-08-15 12 +87 val_87 2010-08-15 11 +87 val_87 2010-08-15 11 +87 val_87 2010-08-15 12 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +92 val_92 2010-08-15 11 +92 val_92 2010-08-15 11 +92 val_92 2010-08-15 12 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 12 +95 val_95 2010-08-15 12 +96 val_96 2010-08-15 11 +96 val_96 2010-08-15 11 +96 val_96 2010-08-15 12 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 12 +97 val_97 2010-08-15 12 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 12 +98 val_98 2010-08-15 12 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 12 +100 val_100 2010-08-15 12 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 12 +103 val_103 2010-08-15 12 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 12 +104 val_104 2010-08-15 12 +105 val_105 2010-08-15 11 +105 val_105 2010-08-15 11 +105 val_105 2010-08-15 12 +111 val_111 2010-08-15 11 +111 val_111 2010-08-15 11 +111 val_111 2010-08-15 12 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 12 +113 val_113 2010-08-15 12 +114 val_114 2010-08-15 11 +114 val_114 2010-08-15 11 +114 val_114 2010-08-15 12 +116 val_116 2010-08-15 11 +116 val_116 2010-08-15 11 +116 val_116 2010-08-15 12 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 12 +118 val_118 2010-08-15 12 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 12 +120 val_120 2010-08-15 12 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 12 +125 val_125 2010-08-15 12 +126 val_126 2010-08-15 11 +126 val_126 2010-08-15 11 +126 val_126 2010-08-15 12 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 12 +129 val_129 2010-08-15 12 +131 val_131 2010-08-15 11 +131 val_131 2010-08-15 11 +131 val_131 2010-08-15 12 +133 val_133 2010-08-15 11 +133 val_133 2010-08-15 11 +133 val_133 2010-08-15 12 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 12 +134 val_134 2010-08-15 12 +136 val_136 2010-08-15 11 +136 val_136 2010-08-15 11 +136 val_136 2010-08-15 12 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 12 +137 val_137 2010-08-15 12 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +143 val_143 2010-08-15 11 +143 val_143 2010-08-15 11 +143 val_143 2010-08-15 12 +145 val_145 2010-08-15 11 +145 val_145 2010-08-15 11 +145 val_145 2010-08-15 12 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 12 +146 val_146 2010-08-15 12 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 12 +149 val_149 2010-08-15 12 +150 val_150 2010-08-15 11 +150 val_150 2010-08-15 11 +150 val_150 2010-08-15 12 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 12 +152 val_152 2010-08-15 12 +153 val_153 2010-08-15 11 +153 val_153 2010-08-15 11 +153 val_153 2010-08-15 12 +155 val_155 2010-08-15 11 +155 val_155 2010-08-15 11 +155 val_155 2010-08-15 12 +156 val_156 2010-08-15 11 +156 val_156 2010-08-15 11 +156 val_156 2010-08-15 12 +157 val_157 2010-08-15 11 +157 val_157 2010-08-15 11 +157 val_157 2010-08-15 12 +158 val_158 2010-08-15 11 +158 val_158 2010-08-15 11 +158 val_158 2010-08-15 12 +160 val_160 2010-08-15 11 +160 val_160 2010-08-15 11 +160 val_160 2010-08-15 12 +162 val_162 2010-08-15 11 +162 val_162 2010-08-15 11 +162 val_162 2010-08-15 12 +163 val_163 2010-08-15 11 +163 val_163 2010-08-15 11 +163 val_163 2010-08-15 12 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 12 +164 val_164 2010-08-15 12 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 12 +165 val_165 2010-08-15 12 +166 val_166 2010-08-15 11 +166 val_166 2010-08-15 11 +166 val_166 2010-08-15 12 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +168 val_168 2010-08-15 11 +168 val_168 2010-08-15 11 +168 val_168 2010-08-15 12 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +170 val_170 2010-08-15 11 +170 val_170 2010-08-15 11 +170 val_170 2010-08-15 12 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 12 +172 val_172 2010-08-15 12 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 12 +174 val_174 2010-08-15 12 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 12 +175 val_175 2010-08-15 12 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 12 +176 val_176 2010-08-15 12 +177 val_177 2010-08-15 11 +177 val_177 2010-08-15 11 +177 val_177 2010-08-15 12 +178 val_178 2010-08-15 11 +178 val_178 2010-08-15 11 +178 val_178 2010-08-15 12 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 12 +179 val_179 2010-08-15 12 +180 val_180 2010-08-15 11 +180 val_180 2010-08-15 11 +180 val_180 2010-08-15 12 +181 val_181 2010-08-15 11 +181 val_181 2010-08-15 11 +181 val_181 2010-08-15 12 +183 val_183 2010-08-15 11 +183 val_183 2010-08-15 11 +183 val_183 2010-08-15 12 +186 val_186 2010-08-15 11 +186 val_186 2010-08-15 11 +186 val_186 2010-08-15 12 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +189 val_189 2010-08-15 11 +189 val_189 2010-08-15 11 +189 val_189 2010-08-15 12 +190 val_190 2010-08-15 11 +190 val_190 2010-08-15 11 +190 val_190 2010-08-15 12 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 12 +191 val_191 2010-08-15 12 +192 val_192 2010-08-15 11 +192 val_192 2010-08-15 11 +192 val_192 2010-08-15 12 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +194 val_194 2010-08-15 11 +194 val_194 2010-08-15 11 +194 val_194 2010-08-15 12 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 12 +195 val_195 2010-08-15 12 +196 val_196 2010-08-15 11 +196 val_196 2010-08-15 11 +196 val_196 2010-08-15 12 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 12 +197 val_197 2010-08-15 12 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 12 +200 val_200 2010-08-15 12 +201 val_201 2010-08-15 11 +201 val_201 2010-08-15 11 +201 val_201 2010-08-15 12 +202 val_202 2010-08-15 11 +202 val_202 2010-08-15 11 +202 val_202 2010-08-15 12 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 12 +203 val_203 2010-08-15 12 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 12 +205 val_205 2010-08-15 12 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 12 +207 val_207 2010-08-15 12 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 12 +209 val_209 2010-08-15 12 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 12 +213 val_213 2010-08-15 12 +214 val_214 2010-08-15 11 +214 val_214 2010-08-15 11 +214 val_214 2010-08-15 12 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 12 +216 val_216 2010-08-15 12 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 12 +217 val_217 2010-08-15 12 +218 val_218 2010-08-15 11 +218 val_218 2010-08-15 11 +218 val_218 2010-08-15 12 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 12 +219 val_219 2010-08-15 12 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 12 +221 val_221 2010-08-15 12 +222 val_222 2010-08-15 11 +222 val_222 2010-08-15 11 +222 val_222 2010-08-15 12 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 12 +223 val_223 2010-08-15 12 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 12 +224 val_224 2010-08-15 12 +226 val_226 2010-08-15 11 +226 val_226 2010-08-15 11 +226 val_226 2010-08-15 12 +228 val_228 2010-08-15 11 +228 val_228 2010-08-15 11 +228 val_228 2010-08-15 12 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 12 +229 val_229 2010-08-15 12 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 12 +233 val_233 2010-08-15 12 +235 val_235 2010-08-15 11 +235 val_235 2010-08-15 11 +235 val_235 2010-08-15 12 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 12 +237 val_237 2010-08-15 12 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 12 +238 val_238 2010-08-15 12 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 12 +239 val_239 2010-08-15 12 +241 val_241 2010-08-15 11 +241 val_241 2010-08-15 11 +241 val_241 2010-08-15 12 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 12 +242 val_242 2010-08-15 12 +244 val_244 2010-08-15 11 +244 val_244 2010-08-15 11 +244 val_244 2010-08-15 12 +247 val_247 2010-08-15 11 +247 val_247 2010-08-15 11 +247 val_247 2010-08-15 12 +248 val_248 2010-08-15 11 +248 val_248 2010-08-15 11 +248 val_248 2010-08-15 12 +249 val_249 2010-08-15 11 +249 val_249 2010-08-15 11 +249 val_249 2010-08-15 12 +252 val_252 2010-08-15 11 +252 val_252 2010-08-15 11 +252 val_252 2010-08-15 12 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 12 +255 val_255 2010-08-15 12 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 12 +256 val_256 2010-08-15 12 +257 val_257 2010-08-15 11 +257 val_257 2010-08-15 11 +257 val_257 2010-08-15 12 +258 val_258 2010-08-15 11 +258 val_258 2010-08-15 11 +258 val_258 2010-08-15 12 +260 val_260 2010-08-15 11 +260 val_260 2010-08-15 11 +260 val_260 2010-08-15 12 +262 val_262 2010-08-15 11 +262 val_262 2010-08-15 11 +262 val_262 2010-08-15 12 +263 val_263 2010-08-15 11 +263 val_263 2010-08-15 11 +263 val_263 2010-08-15 12 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 12 +265 val_265 2010-08-15 12 +266 val_266 2010-08-15 11 +266 val_266 2010-08-15 11 +266 val_266 2010-08-15 12 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 12 +272 val_272 2010-08-15 12 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +274 val_274 2010-08-15 11 +274 val_274 2010-08-15 11 +274 val_274 2010-08-15 12 +275 val_275 2010-08-15 11 +275 val_275 2010-08-15 11 +275 val_275 2010-08-15 12 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 12 +278 val_278 2010-08-15 12 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 12 +280 val_280 2010-08-15 12 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 12 +281 val_281 2010-08-15 12 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 12 +282 val_282 2010-08-15 12 +283 val_283 2010-08-15 11 +283 val_283 2010-08-15 11 +283 val_283 2010-08-15 12 +284 val_284 2010-08-15 11 +284 val_284 2010-08-15 11 +284 val_284 2010-08-15 12 +285 val_285 2010-08-15 11 +285 val_285 2010-08-15 11 +285 val_285 2010-08-15 12 +286 val_286 2010-08-15 11 +286 val_286 2010-08-15 11 +286 val_286 2010-08-15 12 +287 val_287 2010-08-15 11 +287 val_287 2010-08-15 11 +287 val_287 2010-08-15 12 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 12 +288 val_288 2010-08-15 12 +289 val_289 2010-08-15 11 +289 val_289 2010-08-15 11 +289 val_289 2010-08-15 12 +291 val_291 2010-08-15 11 +291 val_291 2010-08-15 11 +291 val_291 2010-08-15 12 +292 val_292 2010-08-15 11 +292 val_292 2010-08-15 11 +292 val_292 2010-08-15 12 +296 val_296 2010-08-15 11 +296 val_296 2010-08-15 11 +296 val_296 2010-08-15 12 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +302 val_302 2010-08-15 11 +302 val_302 2010-08-15 11 +302 val_302 2010-08-15 12 +305 val_305 2010-08-15 11 +305 val_305 2010-08-15 11 +305 val_305 2010-08-15 12 +306 val_306 2010-08-15 11 +306 val_306 2010-08-15 11 +306 val_306 2010-08-15 12 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 12 +307 val_307 2010-08-15 12 +308 val_308 2010-08-15 11 +308 val_308 2010-08-15 11 +308 val_308 2010-08-15 12 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 12 +309 val_309 2010-08-15 12 +310 val_310 2010-08-15 11 +310 val_310 2010-08-15 11 +310 val_310 2010-08-15 12 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +315 val_315 2010-08-15 11 +315 val_315 2010-08-15 11 +315 val_315 2010-08-15 12 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 12 +317 val_317 2010-08-15 12 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 12 +321 val_321 2010-08-15 12 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 12 +322 val_322 2010-08-15 12 +323 val_323 2010-08-15 11 +323 val_323 2010-08-15 11 +323 val_323 2010-08-15 12 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 12 +325 val_325 2010-08-15 12 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 12 +331 val_331 2010-08-15 12 +332 val_332 2010-08-15 11 +332 val_332 2010-08-15 11 +332 val_332 2010-08-15 12 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 12 +333 val_333 2010-08-15 12 +335 val_335 2010-08-15 11 +335 val_335 2010-08-15 11 +335 val_335 2010-08-15 12 +336 val_336 2010-08-15 11 +336 val_336 2010-08-15 11 +336 val_336 2010-08-15 12 +338 val_338 2010-08-15 11 +338 val_338 2010-08-15 11 +338 val_338 2010-08-15 12 +339 val_339 2010-08-15 11 +339 val_339 2010-08-15 11 +339 val_339 2010-08-15 12 +341 val_341 2010-08-15 11 +341 val_341 2010-08-15 11 +341 val_341 2010-08-15 12 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 12 +342 val_342 2010-08-15 12 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 12 +344 val_344 2010-08-15 12 +345 val_345 2010-08-15 11 +345 val_345 2010-08-15 11 +345 val_345 2010-08-15 12 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +351 val_351 2010-08-15 11 +351 val_351 2010-08-15 11 +351 val_351 2010-08-15 12 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 12 +353 val_353 2010-08-15 12 +356 val_356 2010-08-15 11 +356 val_356 2010-08-15 11 +356 val_356 2010-08-15 12 +360 val_360 2010-08-15 11 +360 val_360 2010-08-15 11 +360 val_360 2010-08-15 12 +362 val_362 2010-08-15 11 +362 val_362 2010-08-15 11 +362 val_362 2010-08-15 12 +364 val_364 2010-08-15 11 +364 val_364 2010-08-15 11 +364 val_364 2010-08-15 12 +365 val_365 2010-08-15 11 +365 val_365 2010-08-15 11 +365 val_365 2010-08-15 12 +366 val_366 2010-08-15 11 +366 val_366 2010-08-15 11 +366 val_366 2010-08-15 12 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 12 +367 val_367 2010-08-15 12 +368 val_368 2010-08-15 11 +368 val_368 2010-08-15 11 +368 val_368 2010-08-15 12 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +373 val_373 2010-08-15 11 +373 val_373 2010-08-15 11 +373 val_373 2010-08-15 12 +374 val_374 2010-08-15 11 +374 val_374 2010-08-15 11 +374 val_374 2010-08-15 12 +375 val_375 2010-08-15 11 +375 val_375 2010-08-15 11 +375 val_375 2010-08-15 12 +377 val_377 2010-08-15 11 +377 val_377 2010-08-15 11 +377 val_377 2010-08-15 12 +378 val_378 2010-08-15 11 +378 val_378 2010-08-15 11 +378 val_378 2010-08-15 12 +379 val_379 2010-08-15 11 +379 val_379 2010-08-15 11 +379 val_379 2010-08-15 12 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 12 +382 val_382 2010-08-15 12 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +386 val_386 2010-08-15 11 +386 val_386 2010-08-15 11 +386 val_386 2010-08-15 12 +389 val_389 2010-08-15 11 +389 val_389 2010-08-15 11 +389 val_389 2010-08-15 12 +392 val_392 2010-08-15 11 +392 val_392 2010-08-15 11 +392 val_392 2010-08-15 12 +393 val_393 2010-08-15 11 +393 val_393 2010-08-15 11 +393 val_393 2010-08-15 12 +394 val_394 2010-08-15 11 +394 val_394 2010-08-15 11 +394 val_394 2010-08-15 12 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 12 +395 val_395 2010-08-15 12 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 12 +397 val_397 2010-08-15 12 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 12 +399 val_399 2010-08-15 12 +400 val_400 2010-08-15 11 +400 val_400 2010-08-15 11 +400 val_400 2010-08-15 12 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +402 val_402 2010-08-15 11 +402 val_402 2010-08-15 11 +402 val_402 2010-08-15 12 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 12 +404 val_404 2010-08-15 12 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +407 val_407 2010-08-15 11 +407 val_407 2010-08-15 11 +407 val_407 2010-08-15 12 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +411 val_411 2010-08-15 11 +411 val_411 2010-08-15 11 +411 val_411 2010-08-15 12 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 12 +413 val_413 2010-08-15 12 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 12 +414 val_414 2010-08-15 12 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +418 val_418 2010-08-15 11 +418 val_418 2010-08-15 11 +418 val_418 2010-08-15 12 +419 val_419 2010-08-15 11 +419 val_419 2010-08-15 11 +419 val_419 2010-08-15 12 +421 val_421 2010-08-15 11 +421 val_421 2010-08-15 11 +421 val_421 2010-08-15 12 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 12 +424 val_424 2010-08-15 12 +427 val_427 2010-08-15 11 +427 val_427 2010-08-15 11 +427 val_427 2010-08-15 12 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 12 +429 val_429 2010-08-15 12 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +432 val_432 2010-08-15 11 +432 val_432 2010-08-15 11 +432 val_432 2010-08-15 12 +435 val_435 2010-08-15 11 +435 val_435 2010-08-15 11 +435 val_435 2010-08-15 12 +436 val_436 2010-08-15 11 +436 val_436 2010-08-15 11 +436 val_436 2010-08-15 12 +437 val_437 2010-08-15 11 +437 val_437 2010-08-15 11 +437 val_437 2010-08-15 12 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 12 +439 val_439 2010-08-15 12 +443 val_443 2010-08-15 11 +443 val_443 2010-08-15 11 +443 val_443 2010-08-15 12 +444 val_444 2010-08-15 11 +444 val_444 2010-08-15 11 +444 val_444 2010-08-15 12 +446 val_446 2010-08-15 11 +446 val_446 2010-08-15 11 +446 val_446 2010-08-15 12 +448 val_448 2010-08-15 11 +448 val_448 2010-08-15 11 +448 val_448 2010-08-15 12 +449 val_449 2010-08-15 11 +449 val_449 2010-08-15 11 +449 val_449 2010-08-15 12 +452 val_452 2010-08-15 11 +452 val_452 2010-08-15 11 +452 val_452 2010-08-15 12 +453 val_453 2010-08-15 11 +453 val_453 2010-08-15 11 +453 val_453 2010-08-15 12 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +455 val_455 2010-08-15 11 +455 val_455 2010-08-15 11 +455 val_455 2010-08-15 12 +457 val_457 2010-08-15 11 +457 val_457 2010-08-15 11 +457 val_457 2010-08-15 12 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 12 +458 val_458 2010-08-15 12 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 12 +459 val_459 2010-08-15 12 +460 val_460 2010-08-15 11 +460 val_460 2010-08-15 11 +460 val_460 2010-08-15 12 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 12 +462 val_462 2010-08-15 12 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 12 +463 val_463 2010-08-15 12 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +467 val_467 2010-08-15 11 +467 val_467 2010-08-15 11 +467 val_467 2010-08-15 12 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +470 val_470 2010-08-15 11 +470 val_470 2010-08-15 11 +470 val_470 2010-08-15 12 +472 val_472 2010-08-15 11 +472 val_472 2010-08-15 11 +472 val_472 2010-08-15 12 +475 val_475 2010-08-15 11 +475 val_475 2010-08-15 11 +475 val_475 2010-08-15 12 +477 val_477 2010-08-15 11 +477 val_477 2010-08-15 11 +477 val_477 2010-08-15 12 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 12 +478 val_478 2010-08-15 12 +479 val_479 2010-08-15 11 +479 val_479 2010-08-15 11 +479 val_479 2010-08-15 12 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +481 val_481 2010-08-15 11 +481 val_481 2010-08-15 11 +481 val_481 2010-08-15 12 +482 val_482 2010-08-15 11 +482 val_482 2010-08-15 11 +482 val_482 2010-08-15 12 +483 val_483 2010-08-15 11 +483 val_483 2010-08-15 11 +483 val_483 2010-08-15 12 +484 val_484 2010-08-15 11 +484 val_484 2010-08-15 11 +484 val_484 2010-08-15 12 +485 val_485 2010-08-15 11 +485 val_485 2010-08-15 11 +485 val_485 2010-08-15 12 +487 val_487 2010-08-15 11 +487 val_487 2010-08-15 11 +487 val_487 2010-08-15 12 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +490 val_490 2010-08-15 11 +490 val_490 2010-08-15 11 +490 val_490 2010-08-15 12 +491 val_491 2010-08-15 11 +491 val_491 2010-08-15 11 +491 val_491 2010-08-15 12 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 12 +492 val_492 2010-08-15 12 +493 val_493 2010-08-15 11 +493 val_493 2010-08-15 11 +493 val_493 2010-08-15 12 +494 val_494 2010-08-15 11 +494 val_494 2010-08-15 11 +494 val_494 2010-08-15 12 +495 val_495 2010-08-15 11 +495 val_495 2010-08-15 11 +495 val_495 2010-08-15 12 +496 val_496 2010-08-15 11 +496 val_496 2010-08-15 11 +496 val_496 2010-08-15 12 +497 val_497 2010-08-15 11 +497 val_497 2010-08-15 11 +497 val_497 2010-08-15 12 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 diff --git a/sql/hive/src/test/resources/golden/merge4-12-62541540a18d68a3cb8497a741061d11 b/sql/hive/src/test/resources/golden/merge4-12-62541540a18d68a3cb8497a741061d11 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/merge4-13-ed1103f06609365b40e78d13c654cc71 b/sql/hive/src/test/resources/golden/merge4-13-ed1103f06609365b40e78d13c654cc71 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/merge4-14-ba5dbcd0527b8ddab284bc322255bfc7 b/sql/hive/src/test/resources/golden/merge4-14-ba5dbcd0527b8ddab284bc322255bfc7 new file mode 100644 index 0000000000000..30becc42d7b5a --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-14-ba5dbcd0527b8ddab284bc322255bfc7 @@ -0,0 +1,3 @@ +ds=2010-08-15/hr=11 +ds=2010-08-15/hr=12 +ds=2010-08-15/hr=file, diff --git a/sql/hive/src/test/resources/golden/merge4-15-68f50dc2ad6ff803a372bdd88dd8e19a b/sql/hive/src/test/resources/golden/merge4-15-68f50dc2ad6ff803a372bdd88dd8e19a new file mode 100644 index 0000000000000..4c867a5deff08 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-15-68f50dc2ad6ff803a372bdd88dd8e19a @@ -0,0 +1 @@ +1 1 2010-08-15 file, diff --git a/sql/hive/src/test/resources/golden/merge4-2-43d53504df013e6b35f81811138a167a b/sql/hive/src/test/resources/golden/merge4-2-43d53504df013e6b35f81811138a167a new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-2-43d53504df013e6b35f81811138a167a @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-3-a4fb8359a2179ec70777aad6366071b7 b/sql/hive/src/test/resources/golden/merge4-3-a4fb8359a2179ec70777aad6366071b7 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-3-a4fb8359a2179ec70777aad6366071b7 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-4-16367c381d4b189b3640c92511244bfe b/sql/hive/src/test/resources/golden/merge4-4-16367c381d4b189b3640c92511244bfe new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-4-16367c381d4b189b3640c92511244bfe @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-5-3d24d877366c42030f6d9a596665720d b/sql/hive/src/test/resources/golden/merge4-5-3d24d877366c42030f6d9a596665720d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/merge4-6-b3a76420183795720ab3a384046e5af b/sql/hive/src/test/resources/golden/merge4-6-b3a76420183795720ab3a384046e5af new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/merge4-7-631a45828eae3f5f562d992efe4cd56d b/sql/hive/src/test/resources/golden/merge4-7-631a45828eae3f5f562d992efe4cd56d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/merge4-8-f407e661307b23a5d52a08a3e7af19b b/sql/hive/src/test/resources/golden/merge4-8-f407e661307b23a5d52a08a3e7af19b new file mode 100644 index 0000000000000..aa972caa5665d --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-8-f407e661307b23a5d52a08a3e7af19b @@ -0,0 +1,1000 @@ +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +2 val_2 2010-08-15 11 +2 val_2 2010-08-15 12 +4 val_4 2010-08-15 11 +4 val_4 2010-08-15 12 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +8 val_8 2010-08-15 11 +8 val_8 2010-08-15 12 +9 val_9 2010-08-15 11 +9 val_9 2010-08-15 12 +10 val_10 2010-08-15 11 +10 val_10 2010-08-15 12 +11 val_11 2010-08-15 11 +11 val_11 2010-08-15 12 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 12 +12 val_12 2010-08-15 12 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 12 +15 val_15 2010-08-15 12 +17 val_17 2010-08-15 11 +17 val_17 2010-08-15 12 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 12 +18 val_18 2010-08-15 12 +19 val_19 2010-08-15 11 +19 val_19 2010-08-15 12 +20 val_20 2010-08-15 11 +20 val_20 2010-08-15 12 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 12 +24 val_24 2010-08-15 12 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 12 +26 val_26 2010-08-15 12 +27 val_27 2010-08-15 11 +27 val_27 2010-08-15 12 +28 val_28 2010-08-15 11 +28 val_28 2010-08-15 12 +30 val_30 2010-08-15 11 +30 val_30 2010-08-15 12 +33 val_33 2010-08-15 11 +33 val_33 2010-08-15 12 +34 val_34 2010-08-15 11 +34 val_34 2010-08-15 12 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 12 +37 val_37 2010-08-15 12 +41 val_41 2010-08-15 11 +41 val_41 2010-08-15 12 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 12 +42 val_42 2010-08-15 12 +43 val_43 2010-08-15 11 +43 val_43 2010-08-15 12 +44 val_44 2010-08-15 11 +44 val_44 2010-08-15 12 +47 val_47 2010-08-15 11 +47 val_47 2010-08-15 12 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 12 +51 val_51 2010-08-15 12 +53 val_53 2010-08-15 11 +53 val_53 2010-08-15 12 +54 val_54 2010-08-15 11 +54 val_54 2010-08-15 12 +57 val_57 2010-08-15 11 +57 val_57 2010-08-15 12 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 12 +58 val_58 2010-08-15 12 +64 val_64 2010-08-15 11 +64 val_64 2010-08-15 12 +65 val_65 2010-08-15 11 +65 val_65 2010-08-15 12 +66 val_66 2010-08-15 11 +66 val_66 2010-08-15 12 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 12 +67 val_67 2010-08-15 12 +69 val_69 2010-08-15 11 +69 val_69 2010-08-15 12 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 12 +72 val_72 2010-08-15 12 +74 val_74 2010-08-15 11 +74 val_74 2010-08-15 12 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 12 +76 val_76 2010-08-15 12 +77 val_77 2010-08-15 11 +77 val_77 2010-08-15 12 +78 val_78 2010-08-15 11 +78 val_78 2010-08-15 12 +80 val_80 2010-08-15 11 +80 val_80 2010-08-15 12 +82 val_82 2010-08-15 11 +82 val_82 2010-08-15 12 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 12 +83 val_83 2010-08-15 12 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 12 +84 val_84 2010-08-15 12 +85 val_85 2010-08-15 11 +85 val_85 2010-08-15 12 +86 val_86 2010-08-15 11 +86 val_86 2010-08-15 12 +87 val_87 2010-08-15 11 +87 val_87 2010-08-15 12 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +92 val_92 2010-08-15 11 +92 val_92 2010-08-15 12 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 12 +95 val_95 2010-08-15 12 +96 val_96 2010-08-15 11 +96 val_96 2010-08-15 12 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 12 +97 val_97 2010-08-15 12 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 12 +98 val_98 2010-08-15 12 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 12 +100 val_100 2010-08-15 12 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 12 +103 val_103 2010-08-15 12 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 12 +104 val_104 2010-08-15 12 +105 val_105 2010-08-15 11 +105 val_105 2010-08-15 12 +111 val_111 2010-08-15 11 +111 val_111 2010-08-15 12 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 12 +113 val_113 2010-08-15 12 +114 val_114 2010-08-15 11 +114 val_114 2010-08-15 12 +116 val_116 2010-08-15 11 +116 val_116 2010-08-15 12 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 12 +118 val_118 2010-08-15 12 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 12 +120 val_120 2010-08-15 12 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 12 +125 val_125 2010-08-15 12 +126 val_126 2010-08-15 11 +126 val_126 2010-08-15 12 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 12 +129 val_129 2010-08-15 12 +131 val_131 2010-08-15 11 +131 val_131 2010-08-15 12 +133 val_133 2010-08-15 11 +133 val_133 2010-08-15 12 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 12 +134 val_134 2010-08-15 12 +136 val_136 2010-08-15 11 +136 val_136 2010-08-15 12 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 12 +137 val_137 2010-08-15 12 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +143 val_143 2010-08-15 11 +143 val_143 2010-08-15 12 +145 val_145 2010-08-15 11 +145 val_145 2010-08-15 12 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 12 +146 val_146 2010-08-15 12 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 12 +149 val_149 2010-08-15 12 +150 val_150 2010-08-15 11 +150 val_150 2010-08-15 12 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 12 +152 val_152 2010-08-15 12 +153 val_153 2010-08-15 11 +153 val_153 2010-08-15 12 +155 val_155 2010-08-15 11 +155 val_155 2010-08-15 12 +156 val_156 2010-08-15 11 +156 val_156 2010-08-15 12 +157 val_157 2010-08-15 11 +157 val_157 2010-08-15 12 +158 val_158 2010-08-15 11 +158 val_158 2010-08-15 12 +160 val_160 2010-08-15 11 +160 val_160 2010-08-15 12 +162 val_162 2010-08-15 11 +162 val_162 2010-08-15 12 +163 val_163 2010-08-15 11 +163 val_163 2010-08-15 12 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 12 +164 val_164 2010-08-15 12 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 12 +165 val_165 2010-08-15 12 +166 val_166 2010-08-15 11 +166 val_166 2010-08-15 12 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +168 val_168 2010-08-15 11 +168 val_168 2010-08-15 12 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +170 val_170 2010-08-15 11 +170 val_170 2010-08-15 12 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 12 +172 val_172 2010-08-15 12 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 12 +174 val_174 2010-08-15 12 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 12 +175 val_175 2010-08-15 12 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 12 +176 val_176 2010-08-15 12 +177 val_177 2010-08-15 11 +177 val_177 2010-08-15 12 +178 val_178 2010-08-15 11 +178 val_178 2010-08-15 12 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 12 +179 val_179 2010-08-15 12 +180 val_180 2010-08-15 11 +180 val_180 2010-08-15 12 +181 val_181 2010-08-15 11 +181 val_181 2010-08-15 12 +183 val_183 2010-08-15 11 +183 val_183 2010-08-15 12 +186 val_186 2010-08-15 11 +186 val_186 2010-08-15 12 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +189 val_189 2010-08-15 11 +189 val_189 2010-08-15 12 +190 val_190 2010-08-15 11 +190 val_190 2010-08-15 12 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 12 +191 val_191 2010-08-15 12 +192 val_192 2010-08-15 11 +192 val_192 2010-08-15 12 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +194 val_194 2010-08-15 11 +194 val_194 2010-08-15 12 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 12 +195 val_195 2010-08-15 12 +196 val_196 2010-08-15 11 +196 val_196 2010-08-15 12 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 12 +197 val_197 2010-08-15 12 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 12 +200 val_200 2010-08-15 12 +201 val_201 2010-08-15 11 +201 val_201 2010-08-15 12 +202 val_202 2010-08-15 11 +202 val_202 2010-08-15 12 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 12 +203 val_203 2010-08-15 12 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 12 +205 val_205 2010-08-15 12 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 12 +207 val_207 2010-08-15 12 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 12 +209 val_209 2010-08-15 12 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 12 +213 val_213 2010-08-15 12 +214 val_214 2010-08-15 11 +214 val_214 2010-08-15 12 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 12 +216 val_216 2010-08-15 12 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 12 +217 val_217 2010-08-15 12 +218 val_218 2010-08-15 11 +218 val_218 2010-08-15 12 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 12 +219 val_219 2010-08-15 12 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 12 +221 val_221 2010-08-15 12 +222 val_222 2010-08-15 11 +222 val_222 2010-08-15 12 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 12 +223 val_223 2010-08-15 12 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 12 +224 val_224 2010-08-15 12 +226 val_226 2010-08-15 11 +226 val_226 2010-08-15 12 +228 val_228 2010-08-15 11 +228 val_228 2010-08-15 12 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 12 +229 val_229 2010-08-15 12 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 12 +233 val_233 2010-08-15 12 +235 val_235 2010-08-15 11 +235 val_235 2010-08-15 12 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 12 +237 val_237 2010-08-15 12 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 12 +238 val_238 2010-08-15 12 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 12 +239 val_239 2010-08-15 12 +241 val_241 2010-08-15 11 +241 val_241 2010-08-15 12 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 12 +242 val_242 2010-08-15 12 +244 val_244 2010-08-15 11 +244 val_244 2010-08-15 12 +247 val_247 2010-08-15 11 +247 val_247 2010-08-15 12 +248 val_248 2010-08-15 11 +248 val_248 2010-08-15 12 +249 val_249 2010-08-15 11 +249 val_249 2010-08-15 12 +252 val_252 2010-08-15 11 +252 val_252 2010-08-15 12 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 12 +255 val_255 2010-08-15 12 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 12 +256 val_256 2010-08-15 12 +257 val_257 2010-08-15 11 +257 val_257 2010-08-15 12 +258 val_258 2010-08-15 11 +258 val_258 2010-08-15 12 +260 val_260 2010-08-15 11 +260 val_260 2010-08-15 12 +262 val_262 2010-08-15 11 +262 val_262 2010-08-15 12 +263 val_263 2010-08-15 11 +263 val_263 2010-08-15 12 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 12 +265 val_265 2010-08-15 12 +266 val_266 2010-08-15 11 +266 val_266 2010-08-15 12 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 12 +272 val_272 2010-08-15 12 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +274 val_274 2010-08-15 11 +274 val_274 2010-08-15 12 +275 val_275 2010-08-15 11 +275 val_275 2010-08-15 12 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 12 +278 val_278 2010-08-15 12 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 12 +280 val_280 2010-08-15 12 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 12 +281 val_281 2010-08-15 12 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 12 +282 val_282 2010-08-15 12 +283 val_283 2010-08-15 11 +283 val_283 2010-08-15 12 +284 val_284 2010-08-15 11 +284 val_284 2010-08-15 12 +285 val_285 2010-08-15 11 +285 val_285 2010-08-15 12 +286 val_286 2010-08-15 11 +286 val_286 2010-08-15 12 +287 val_287 2010-08-15 11 +287 val_287 2010-08-15 12 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 12 +288 val_288 2010-08-15 12 +289 val_289 2010-08-15 11 +289 val_289 2010-08-15 12 +291 val_291 2010-08-15 11 +291 val_291 2010-08-15 12 +292 val_292 2010-08-15 11 +292 val_292 2010-08-15 12 +296 val_296 2010-08-15 11 +296 val_296 2010-08-15 12 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +302 val_302 2010-08-15 11 +302 val_302 2010-08-15 12 +305 val_305 2010-08-15 11 +305 val_305 2010-08-15 12 +306 val_306 2010-08-15 11 +306 val_306 2010-08-15 12 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 12 +307 val_307 2010-08-15 12 +308 val_308 2010-08-15 11 +308 val_308 2010-08-15 12 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 12 +309 val_309 2010-08-15 12 +310 val_310 2010-08-15 11 +310 val_310 2010-08-15 12 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +315 val_315 2010-08-15 11 +315 val_315 2010-08-15 12 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 12 +317 val_317 2010-08-15 12 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 12 +321 val_321 2010-08-15 12 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 12 +322 val_322 2010-08-15 12 +323 val_323 2010-08-15 11 +323 val_323 2010-08-15 12 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 12 +325 val_325 2010-08-15 12 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 12 +331 val_331 2010-08-15 12 +332 val_332 2010-08-15 11 +332 val_332 2010-08-15 12 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 12 +333 val_333 2010-08-15 12 +335 val_335 2010-08-15 11 +335 val_335 2010-08-15 12 +336 val_336 2010-08-15 11 +336 val_336 2010-08-15 12 +338 val_338 2010-08-15 11 +338 val_338 2010-08-15 12 +339 val_339 2010-08-15 11 +339 val_339 2010-08-15 12 +341 val_341 2010-08-15 11 +341 val_341 2010-08-15 12 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 12 +342 val_342 2010-08-15 12 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 12 +344 val_344 2010-08-15 12 +345 val_345 2010-08-15 11 +345 val_345 2010-08-15 12 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +351 val_351 2010-08-15 11 +351 val_351 2010-08-15 12 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 12 +353 val_353 2010-08-15 12 +356 val_356 2010-08-15 11 +356 val_356 2010-08-15 12 +360 val_360 2010-08-15 11 +360 val_360 2010-08-15 12 +362 val_362 2010-08-15 11 +362 val_362 2010-08-15 12 +364 val_364 2010-08-15 11 +364 val_364 2010-08-15 12 +365 val_365 2010-08-15 11 +365 val_365 2010-08-15 12 +366 val_366 2010-08-15 11 +366 val_366 2010-08-15 12 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 12 +367 val_367 2010-08-15 12 +368 val_368 2010-08-15 11 +368 val_368 2010-08-15 12 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +373 val_373 2010-08-15 11 +373 val_373 2010-08-15 12 +374 val_374 2010-08-15 11 +374 val_374 2010-08-15 12 +375 val_375 2010-08-15 11 +375 val_375 2010-08-15 12 +377 val_377 2010-08-15 11 +377 val_377 2010-08-15 12 +378 val_378 2010-08-15 11 +378 val_378 2010-08-15 12 +379 val_379 2010-08-15 11 +379 val_379 2010-08-15 12 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 12 +382 val_382 2010-08-15 12 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +386 val_386 2010-08-15 11 +386 val_386 2010-08-15 12 +389 val_389 2010-08-15 11 +389 val_389 2010-08-15 12 +392 val_392 2010-08-15 11 +392 val_392 2010-08-15 12 +393 val_393 2010-08-15 11 +393 val_393 2010-08-15 12 +394 val_394 2010-08-15 11 +394 val_394 2010-08-15 12 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 12 +395 val_395 2010-08-15 12 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 12 +397 val_397 2010-08-15 12 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 12 +399 val_399 2010-08-15 12 +400 val_400 2010-08-15 11 +400 val_400 2010-08-15 12 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +402 val_402 2010-08-15 11 +402 val_402 2010-08-15 12 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 12 +404 val_404 2010-08-15 12 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +407 val_407 2010-08-15 11 +407 val_407 2010-08-15 12 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +411 val_411 2010-08-15 11 +411 val_411 2010-08-15 12 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 12 +413 val_413 2010-08-15 12 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 12 +414 val_414 2010-08-15 12 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +418 val_418 2010-08-15 11 +418 val_418 2010-08-15 12 +419 val_419 2010-08-15 11 +419 val_419 2010-08-15 12 +421 val_421 2010-08-15 11 +421 val_421 2010-08-15 12 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 12 +424 val_424 2010-08-15 12 +427 val_427 2010-08-15 11 +427 val_427 2010-08-15 12 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 12 +429 val_429 2010-08-15 12 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +432 val_432 2010-08-15 11 +432 val_432 2010-08-15 12 +435 val_435 2010-08-15 11 +435 val_435 2010-08-15 12 +436 val_436 2010-08-15 11 +436 val_436 2010-08-15 12 +437 val_437 2010-08-15 11 +437 val_437 2010-08-15 12 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 12 +439 val_439 2010-08-15 12 +443 val_443 2010-08-15 11 +443 val_443 2010-08-15 12 +444 val_444 2010-08-15 11 +444 val_444 2010-08-15 12 +446 val_446 2010-08-15 11 +446 val_446 2010-08-15 12 +448 val_448 2010-08-15 11 +448 val_448 2010-08-15 12 +449 val_449 2010-08-15 11 +449 val_449 2010-08-15 12 +452 val_452 2010-08-15 11 +452 val_452 2010-08-15 12 +453 val_453 2010-08-15 11 +453 val_453 2010-08-15 12 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +455 val_455 2010-08-15 11 +455 val_455 2010-08-15 12 +457 val_457 2010-08-15 11 +457 val_457 2010-08-15 12 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 12 +458 val_458 2010-08-15 12 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 12 +459 val_459 2010-08-15 12 +460 val_460 2010-08-15 11 +460 val_460 2010-08-15 12 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 12 +462 val_462 2010-08-15 12 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 12 +463 val_463 2010-08-15 12 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +467 val_467 2010-08-15 11 +467 val_467 2010-08-15 12 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +470 val_470 2010-08-15 11 +470 val_470 2010-08-15 12 +472 val_472 2010-08-15 11 +472 val_472 2010-08-15 12 +475 val_475 2010-08-15 11 +475 val_475 2010-08-15 12 +477 val_477 2010-08-15 11 +477 val_477 2010-08-15 12 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 12 +478 val_478 2010-08-15 12 +479 val_479 2010-08-15 11 +479 val_479 2010-08-15 12 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +481 val_481 2010-08-15 11 +481 val_481 2010-08-15 12 +482 val_482 2010-08-15 11 +482 val_482 2010-08-15 12 +483 val_483 2010-08-15 11 +483 val_483 2010-08-15 12 +484 val_484 2010-08-15 11 +484 val_484 2010-08-15 12 +485 val_485 2010-08-15 11 +485 val_485 2010-08-15 12 +487 val_487 2010-08-15 11 +487 val_487 2010-08-15 12 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +490 val_490 2010-08-15 11 +490 val_490 2010-08-15 12 +491 val_491 2010-08-15 11 +491 val_491 2010-08-15 12 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 12 +492 val_492 2010-08-15 12 +493 val_493 2010-08-15 11 +493 val_493 2010-08-15 12 +494 val_494 2010-08-15 11 +494 val_494 2010-08-15 12 +495 val_495 2010-08-15 11 +495 val_495 2010-08-15 12 +496 val_496 2010-08-15 11 +496 val_496 2010-08-15 12 +497 val_497 2010-08-15 11 +497 val_497 2010-08-15 12 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 diff --git a/sql/hive/src/test/resources/golden/merge4-9-ad3dc168c8b6f048717e39ab16b0a319 b/sql/hive/src/test/resources/golden/merge4-9-ad3dc168c8b6f048717e39ab16b0a319 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/no from clause-0-b42b408a87b258921240058f880a721a b/sql/hive/src/test/resources/golden/no from clause-0-b42b408a87b258921240058f880a721a new file mode 100644 index 0000000000000..390d344ecb9d3 --- /dev/null +++ b/sql/hive/src/test/resources/golden/no from clause-0-b42b408a87b258921240058f880a721a @@ -0,0 +1 @@ +1 1 -1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-0-36f9196395758cebfed837a1c391a1e b/sql/hive/src/test/resources/golden/nullformatCTAS-0-36f9196395758cebfed837a1c391a1e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-1-b5a31d4cb34218b8de1ac3fed59fa75b b/sql/hive/src/test/resources/golden/nullformatCTAS-1-b5a31d4cb34218b8de1ac3fed59fa75b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b b/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b new file mode 100644 index 0000000000000..e74deff51c9ba --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b @@ -0,0 +1,10 @@ +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.0 1 +NULL 1 +NULL NULL +1.0 NULL +1.0 1 +1.0 1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-11-4a4c16b53c612d00012d338c97bf5281 b/sql/hive/src/test/resources/golden/nullformatCTAS-11-4a4c16b53c612d00012d338c97bf5281 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b b/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b new file mode 100644 index 0000000000000..00ebb521970dd --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b @@ -0,0 +1,10 @@ +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.0 1 +fooNull 1 +fooNull fooNull +1.0 fooNull +1.0 1 +1.0 1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-13-2e59caa113585495d8684fee69d88bc0 b/sql/hive/src/test/resources/golden/nullformatCTAS-13-2e59caa113585495d8684fee69d88bc0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-14-ad9fe9d68c2cf492259af4f6167c1b12 b/sql/hive/src/test/resources/golden/nullformatCTAS-14-ad9fe9d68c2cf492259af4f6167c1b12 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-2-aa2bdbd93668dceae43d1a02f2ede68d b/sql/hive/src/test/resources/golden/nullformatCTAS-2-aa2bdbd93668dceae43d1a02f2ede68d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273 b/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273 new file mode 100644 index 0000000000000..b00bcb3624532 --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273 @@ -0,0 +1,6 @@ +a string +b string +c string +d string + +Detailed Table Information Table(tableName:base_tab, dbName:default, owner:animal, createTime:1423973915, lastAccessTime:0, retention:0, sd:StorageDescriptor(cols:[FieldSchema(name:a, type:string, comment:null), FieldSchema(name:b, type:string, comment:null), FieldSchema(name:c, type:string, comment:null), FieldSchema(name:d, type:string, comment:null)], location:file:/tmp/sparkHiveWarehouse2573474017665704744/base_tab, inputFormat:org.apache.hadoop.mapred.TextInputFormat, outputFormat:org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat, compressed:false, numBuckets:-1, serdeInfo:SerDeInfo(name:null, serializationLib:org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, parameters:{serialization.format=1}), bucketCols:[], sortCols:[], parameters:{}, skewedInfo:SkewedInfo(skewedColNames:[], skewedColValues:[], skewedColValueLocationMaps:{}), storedAsSubDirectories:false), partitionKeys:[], parameters:{numFiles=1, transient_lastDdlTime=1423973915, COLUMN_STATS_ACCURATE=true, totalSize=130, numRows=0, rawDataSize=0}, viewOriginalText:null, viewExpandedText:null, tableType:MANAGED_TABLE) diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-4-16c7086f39d6458b6c5cf2479f0473bd b/sql/hive/src/test/resources/golden/nullformatCTAS-4-16c7086f39d6458b6c5cf2479f0473bd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-5-183d77b734ce6a373de5b3ebe1cd04c9 b/sql/hive/src/test/resources/golden/nullformatCTAS-5-183d77b734ce6a373de5b3ebe1cd04c9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-6-159fff36b548e00ee952d1df8ef19833 b/sql/hive/src/test/resources/golden/nullformatCTAS-6-159fff36b548e00ee952d1df8ef19833 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65 b/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65 new file mode 100644 index 0000000000000..264c973ff7af1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65 @@ -0,0 +1,4 @@ +a string +b string + +Detailed Table Information Table(tableName:null_tab3, dbName:default, owner:animal, createTime:1423973928, lastAccessTime:0, retention:0, sd:StorageDescriptor(cols:[FieldSchema(name:a, type:string, comment:null), FieldSchema(name:b, type:string, comment:null)], location:file:/tmp/sparkHiveWarehouse2573474017665704744/null_tab3, inputFormat:org.apache.hadoop.mapred.TextInputFormat, outputFormat:org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat, compressed:false, numBuckets:-1, serdeInfo:SerDeInfo(name:null, serializationLib:org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, parameters:{serialization.null.format=fooNull, serialization.format=1}), bucketCols:[], sortCols:[], parameters:{}, skewedInfo:SkewedInfo(skewedColNames:[], skewedColValues:[], skewedColValueLocationMaps:{}), storedAsSubDirectories:false), partitionKeys:[], parameters:{numFiles=1, transient_lastDdlTime=1423973928, COLUMN_STATS_ACCURATE=true, totalSize=80, numRows=10, rawDataSize=70}, viewOriginalText:null, viewExpandedText:null, tableType:MANAGED_TABLE) diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8 b/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8 new file mode 100644 index 0000000000000..881917bcf1c69 --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8 @@ -0,0 +1,18 @@ +CREATE TABLE `null_tab3`( + `a` string, + `b` string) +ROW FORMAT DELIMITED + NULL DEFINED AS 'fooNull' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse2573474017665704744/null_tab3' +TBLPROPERTIES ( + 'numFiles'='1', + 'transient_lastDdlTime'='1423973928', + 'COLUMN_STATS_ACCURATE'='true', + 'totalSize'='80', + 'numRows'='10', + 'rawDataSize'='70') diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47 b/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47 new file mode 100644 index 0000000000000..3a2e3f4984a0e --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47 @@ -0,0 +1 @@ +-1 diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-0-c6d02549aec166e16bfc44d5905fa33a b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-0-c6d02549aec166e16bfc44d5905fa33a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-1-a8987ff8c7b9ca95bf8b32314694ed1f b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-1-a8987ff8c7b9ca95bf8b32314694ed1f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-2-26f54240cf5b909086fc34a34d7fdb56 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-2-26f54240cf5b909086fc34a34d7fdb56 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-3-d08d5280027adea681001ad82a5a6974 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-3-d08d5280027adea681001ad82a5a6974 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184 b/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184 new file mode 100644 index 0000000000000..cd35e5b290db5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184 @@ -0,0 +1 @@ +reflect2(arg0,method[,arg1[,arg2..]]) calls method of arg0 with reflection diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19 b/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19 new file mode 100644 index 0000000000000..48ef97292ab62 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19 @@ -0,0 +1,3 @@ +reflect2(arg0,method[,arg1[,arg2..]]) calls method of arg0 with reflection +Use this UDF to call Java methods by matching the argument signature + diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-3-effc057c78c00b0af26a4ac0f5f116ca b/sql/hive/src/test/resources/golden/udf_reflect2-3-effc057c78c00b0af26a4ac0f5f116ca new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4 b/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4 new file mode 100644 index 0000000000000..176ea0358d7ea --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4 @@ -0,0 +1,5 @@ +238 -18 238 238 238 238.0 238.0 238 val_238 val_238_concat false true false false false val_238 -1 -1 VALUE_238 al_238 al_2 VAL_238 val_238 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +86 86 86 86 86 86.0 86.0 86 val_86 val_86_concat true true true true true val_86 -1 -1 VALUE_86 al_86 al_8 VAL_86 val_86 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +311 55 311 311 311 311.0 311.0 311 val_311 val_311_concat false true false false false val_311 5 6 VALUE_311 al_311 al_3 VAL_311 val_311 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +27 27 27 27 27 27.0 27.0 27 val_27 val_27_concat false true false false false val_27 -1 -1 VALUE_27 al_27 al_2 VAL_27 val_27 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +165 -91 165 165 165 165.0 165.0 165 val_165 val_165_concat false true false false false val_165 4 4 VALUE_165 al_165 al_1 VAL_165 val_165 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index ba391293884bd..0270e63557963 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql -import org.scalatest.FunSuite +import scala.collection.JavaConversions._ -import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ @@ -55,9 +53,36 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { + QueryTest.checkAnswer(rdd, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(rdd, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } +} + +object QueryTest { + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a [[None]] will + * be returned. + * @param rdd the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. @@ -73,18 +98,20 @@ class QueryTest extends PlanTest { } val sparkAnswer = try rdd.collect().toSeq catch { case e: Exception => - fail( + val errorMessage = s""" |Exception thrown while executing query: |${rdd.queryExecution} |== Exception == |$e |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin) + """.stripMargin + return Some(errorMessage) } if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - fail(s""" + val errorMessage = + s""" |Results do not match for query: |${rdd.logicalPlan} |== Analyzed Plan == @@ -93,22 +120,21 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin) + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + """.stripMargin + return Some(errorMessage) } - } - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) + return None } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = { + checkAnswer(rdd, expectedAnswer.toSeq) match { + case Some(errorMessage) => errorMessage + case None => null } } - } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 081d94b6fc020..98f1c0e69e29d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.scalatest.FunSuite @@ -35,11 +35,11 @@ class PlanTest extends FunSuite { * we must normalize them to check if two different queries are identical. */ protected def normalizeExprIds(plan: LogicalPlan) = { - val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)) - val minId = if (list.isEmpty) 0 else list.min plan transformAllExpressions { case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId)) + AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) + case a: Alias => + Alias(a.child, a.name)(exprId = ExprId(0)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 7c8b5205e239e..44d24273e722a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId class CachedTableSuite extends QueryTest { @@ -96,7 +96,7 @@ class CachedTableSuite extends QueryTest { cacheTable("test") sql("SELECT * FROM test").collect() sql("DROP TABLE test") - intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] { + intercept[AnalysisException] { sql("SELECT * FROM test").collect() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala new file mode 100644 index 0000000000000..968557c9c4686 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.{OutputStream, PrintStream} + +import scala.util.Try + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.{AnalysisException, QueryTest} + + +class ErrorPositionSuite extends QueryTest with BeforeAndAfter { + + before { + Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") + } + + positionTest("ambiguous attribute reference 1", + "SELECT a from dupAttributes", "a") + + positionTest("ambiguous attribute reference 2", + "SELECT a, b from dupAttributes", "a") + + positionTest("ambiguous attribute reference 3", + "SELECT b, a from dupAttributes", "a") + + positionTest("unresolved attribute 1", + "SELECT x FROM src", "x") + + positionTest("unresolved attribute 2", + "SELECT x FROM src", "x") + + positionTest("unresolved attribute 3", + "SELECT key, x FROM src", "x") + + positionTest("unresolved attribute 4", + """SELECT key, + |x FROM src + """.stripMargin, "x") + + positionTest("unresolved attribute 5", + """SELECT key, + | x FROM src + """.stripMargin, "x") + + positionTest("unresolved attribute 6", + """SELECT key, + | + | 1 + x FROM src + """.stripMargin, "x") + + positionTest("unresolved attribute 7", + """SELECT key, + | + | 1 + x + 1 FROM src + """.stripMargin, "x") + + positionTest("multi-char unresolved attribute", + """SELECT key, + | + | 1 + abcd + 1 FROM src + """.stripMargin, "abcd") + + positionTest("unresolved attribute group by", + """SELECT key FROM src GROUP BY + |x + """.stripMargin, "x") + + positionTest("unresolved attribute order by", + """SELECT key FROM src ORDER BY + |x + """.stripMargin, "x") + + positionTest("unresolved attribute where", + """SELECT key FROM src + |WHERE x = true + """.stripMargin, "x") + + positionTest("unresolved attribute backticks", + "SELECT `x` FROM src", "`x`") + + positionTest("parse error", + "SELECT WHERE", "WHERE") + + positionTest("bad relation", + "SELECT * FROM badTable", "badTable") + + ignore("other expressions") { + positionTest("bad addition", + "SELECT 1 + array(1)", "1 + array") + } + + /** Hive can be very noisy, messing up the output of our tests. */ + private def quietly[A](f: => A): A = { + val origErr = System.err + val origOut = System.out + try { + System.setErr(new PrintStream(new OutputStream { + def write(b: Int) = {} + })) + System.setOut(new PrintStream(new OutputStream { + def write(b: Int) = {} + })) + + f + } finally { + System.setErr(origErr) + System.setOut(origOut) + } + } + + /** + * Creates a test that checks to see if the error thrown when analyzing a given query includes + * the location of the given token in the query string. + * + * @param name the name of the test + * @param query the query to analyze + * @param token a unique token in the string that should be indicated by the exception + */ + def positionTest(name: String, query: String, token: String) = { + def parseTree = + Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("") + + test(name) { + val error = intercept[AnalysisException] { + quietly(sql(query)) + } + + assert(!error.getMessage.contains("Seq(")) + assert(!error.getMessage.contains("List(")) + + val (line, expectedLineNum) = query.split("\n").zipWithIndex.collect { + case (l, i) if l.contains(token) => (l, i + 1) + }.headOption.getOrElse(sys.error(s"Invalid test. Token $token not in $query")) + val actualLine = error.line.getOrElse { + fail( + s"line not returned for error '${error.getMessage}' on token $token\n$parseTree" + ) + } + assert(actualLine === expectedLineNum, "wrong line") + + val expectedStart = line.indexOf(token) + val actualStart = error.startPosition.getOrElse { + fail( + s"start not returned for error on token $token\n" + + HiveQl.dumpTree(HiveQl.getAst(query)) + ) + } + assert(expectedStart === actualStart, + s"""Incorrect start position. + |== QUERY == + |$query + | + |== AST == + |$parseTree + | + |Actual: $actualStart, Expected: $expectedStart + |$line + |${" " * actualStart}^ + |0123456789 123456789 1234567890 + | 2 3 + """.stripMargin) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 2d3ff680125ad..09bbd5c867e4e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive import java.util -import java.sql.Date import java.util.{Locale, TimeZone} import org.apache.hadoop.hive.ql.udf.UDAFPercentile @@ -76,7 +75,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(0.asInstanceOf[Float]) :: Literal(0.asInstanceOf[Double]) :: Literal("0") :: - Literal(new Date(2014, 9, 23)) :: + Literal(new java.sql.Date(114, 8, 23)) :: Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: Literal(Array[Byte](1,2,3)) :: @@ -143,7 +142,6 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case (r1: Array[Byte], r2: Array[Byte]) if r1 != null && r2 != null && r1.length == r2.length => r1.zip(r2).map { case (b1, b2) => assert(b1 === b2) } - case (r1: Date, r2: Date) => assert(r1.compareTo(r2) === 0) case (r1, r2) => assert(r1 === r2) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index aad48ada52642..fa8e11ffec2b4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.hive.test.TestHive import org.scalatest.FunSuite import org.apache.spark.sql.test.ExamplePointUDT @@ -36,4 +37,11 @@ class HiveMetastoreCatalogSuite extends FunSuite { assert(HiveMetastoreTypes.toMetastoreType(udt) === HiveMetastoreTypes.toMetastoreType(udt.sqlType)) } + + test("duplicated metastore relations") { + import TestHive.implicits._ + val df = TestHive.sql("SELECT * FROM src") + println(df.queryExecution) + df.as('a).join(df.as('b), $"a.key" === $"b.key") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala new file mode 100644 index 0000000000000..7ff5719adb3ab --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.parquet.ParquetTest +import org.apache.spark.sql.{QueryTest, SQLConf} + +case class Cases(lower: String, UPPER: String) + +class HiveParquetSuite extends QueryTest with ParquetTest { + val sqlContext = TestHive + + import sqlContext._ + + def run(prefix: String): Unit = { + test(s"$prefix: Case insensitive attribute names") { + withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { + val expected = (1 to 4).map(i => Row(i.toString)) + checkAnswer(sql("SELECT upper FROM cases"), expected) + checkAnswer(sql("SELECT LOWER FROM cases"), expected) + } + } + + test(s"$prefix: SELECT on Parquet table") { + val data = (1 to 4).map(i => (i, s"val_$i")) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) + } + } + + test(s"$prefix: Simple column projection + filter on Parquet table") { + withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { + checkAnswer( + sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), + Seq(Row(true, "val_2"), Row(true, "val_4"))) + } + } + + test(s"$prefix: Converting Hive to Parquet Table via saveAsParquetFile") { + withTempPath { dir => + sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).registerTempTable("p") + withTempTable("p") { + checkAnswer( + sql("SELECT * FROM src ORDER BY key"), + sql("SELECT * from p ORDER BY key").collect().toSeq) + } + } + } + + test(s"$prefix: INSERT OVERWRITE TABLE Parquet table") { + withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { + withTempPath { file => + sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) + parquetFile(file.getCanonicalPath).registerTempTable("p") + withTempTable("p") { + // let's do three overwrites for good measure + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) + } + } + } + } + } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + run("Parquet data source enabled") + } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { + run("Parquet data source disabled") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 4dd96bd5a1b77..d4b175fa443a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,7 +19,11 @@ package org.apache.spark.sql.hive import java.io.File +import org.scalatest.BeforeAndAfter + import com.google.common.io.Files + +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.types._ @@ -29,13 +33,22 @@ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(key: Int, value: String) -class InsertIntoHiveTableSuite extends QueryTest { +class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { + import org.apache.spark.sql.hive.test.TestHive.implicits._ + val testData = TestHive.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))) - testData.registerTempTable("testData") + (1 to 100).map(i => TestData(i, i.toString))).toDF() + + before { + // Since every we are doing tests for DDL statements, + // it is better to reset before every test. + TestHive.reset() + // Register the testData, which will be used in every test. + testData.registerTempTable("testData") + } test("insertInto() HiveTable") { - createTable[TestData]("createAndInsertTest") + sql("CREATE TABLE createAndInsertTest (key int, value string)") // Add some data. testData.insertInto("createAndInsertTest") @@ -43,7 +56,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq.map(Row.fromTuple) + testData.collect().toSeq ) // Add more data. @@ -52,7 +65,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.toDataFrame.collect().toSeq ++ testData.toDataFrame.collect().toSeq + testData.toDF().collect().toSeq ++ testData.toDF().collect().toSeq ) // Now overwrite. @@ -61,28 +74,30 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the registered table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq.map(Row.fromTuple) + testData.collect().toSeq ) } test("Double create fails when allowExisting = false") { - createTable[TestData]("doubleCreateAndInsertTest") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - intercept[org.apache.hadoop.hive.ql.metadata.HiveException] { - createTable[TestData]("doubleCreateAndInsertTest", allowExisting = false) - } + val message = intercept[QueryExecutionException] { + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + }.getMessage + + println("message!!!!" + message) } test("Double create does not fail when allowExisting = true") { - createTable[TestData]("createAndInsertTest") - createTable[TestData]("createAndInsertTest") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)") } test("SPARK-4052: scala.collection.Map as value type of MapType") { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -96,7 +111,7 @@ class InsertIntoHiveTableSuite extends QueryTest { } test("SPARK-4203:random partition directory order") { - createTable[TestData]("tmp_table") + sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Files.createTempDir() sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") @@ -127,7 +142,7 @@ class InsertIntoHiveTableSuite extends QueryTest { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array )") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -144,7 +159,7 @@ class InsertIntoHiveTableSuite extends QueryTest { StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -161,7 +176,7 @@ class InsertIntoHiveTableSuite extends QueryTest { StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct )") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala new file mode 100644 index 0000000000000..e12a6c21ccac4 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -0,0 +1,81 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.Row + +class ListTablesSuite extends QueryTest with BeforeAndAfterAll { + + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + val df = + sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + + override def beforeAll(): Unit = { + // The catalog in HiveContext is a case insensitive one. + catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) + catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) + sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") + sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") + sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") + } + + override def afterAll(): Unit = { + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) + sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") + sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") + sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") + } + + test("get all tables of current database") { + Seq(tables(), sql("SHOW TABLes")).foreach { + case allTables => + // We are using default DB. + checkAnswer( + allTables.filter("tableName = 'listtablessuitetable'"), + Row("listtablessuitetable", true)) + assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0) + checkAnswer( + allTables.filter("tableName = 'hivelisttablessuitetable'"), + Row("hivelisttablessuitetable", false)) + assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0) + } + } + + test("getting all tables with a database name") { + Seq(tables("listtablessuiteDb"), sql("SHOW TABLes in listTablesSuitedb")).foreach { + case allTables => + checkAnswer( + allTables.filter("tableName = 'listtablessuitetable'"), + Row("listtablessuitetable", true)) + checkAnswer( + allTables.filter("tableName = 'indblisttablessuitetable'"), + Row("indblisttablessuitetable", true)) + assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0) + checkAnswer( + allTables.filter("tableName = 'hiveindblisttablessuitetable'"), + Row("hiveindblisttablessuitetable", false)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 85795acb658e2..be8bd95802b45 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,26 +23,33 @@ import org.scalatest.BeforeAndAfterEach import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.metastore.TableType +import org.apache.hadoop.hive.ql.metadata.Table +import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.sql.catalyst.util import org.apache.spark.sql._ import org.apache.spark.util.Utils import org.apache.spark.sql.types._ - -/* Implicits */ import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.sources.LogicalRelation + +import scala.collection.mutable.ArrayBuffer /** * Tests for persisting tables created though the data sources API into the metastore. */ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { + override def afterEach(): Unit = { reset() - if (ctasPath.exists()) Utils.deleteRecursively(ctasPath) + if (tempPath.exists()) Utils.deleteRecursively(tempPath) } val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile - var ctasPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile + var tempPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile test ("persistent JSON table") { sql( @@ -151,7 +158,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { test("check change without refresh") { val tempDir = File.createTempFile("sparksql", "json") tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a", "b") :: Nil).toDF() + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) sql( s""" @@ -167,7 +175,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { Row("a", "b")) FileUtils.deleteDirectory(tempDir) - sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF() + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) // Schema is cached so the new column does not show. The updated values in existing columns // will show. @@ -175,7 +184,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT * FROM jsonTable"), Row("a1", "b1")) - refreshTable("jsonTable") + sql("REFRESH TABLE jsonTable") // Check that the refresh worked checkAnswer( @@ -187,7 +196,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { test("drop, change, recreate") { val tempDir = File.createTempFile("sparksql", "json") tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a", "b") :: Nil).toDF() + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) sql( s""" @@ -203,7 +213,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { Row("a", "b")) FileUtils.deleteDirectory(tempDir) - sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a", "b", "c") :: Nil).toDF() + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) sql("DROP TABLE jsonTable") @@ -267,7 +278,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT * FROM jsonTable """.stripMargin) @@ -294,19 +305,19 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT * FROM jsonTable """.stripMargin) - // Create the table again should trigger a AlreadyExistsException. - val message = intercept[RuntimeException] { + // Create the table again should trigger a AnalysisException. + val message = intercept[AnalysisException] { sql( s""" |CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT * FROM jsonTable """.stripMargin) @@ -322,7 +333,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |CREATE TABLE IF NOT EXISTS ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT a FROM jsonTable """.stripMargin) @@ -348,9 +359,6 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |) """.stripMargin) - new Path("/Users/yhuai/Desktop/whatever") - - val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") val filesystemPath = new Path(expectedPath) val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) @@ -361,9 +369,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { s""" |CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | - |) AS + |AS |SELECT * FROM jsonTable """.stripMargin) @@ -402,38 +408,352 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("DROP TABLE jsonTable").collect().foreach(println) } - test("save and load table") { + test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { val originalDefaultSource = conf.defaultDataSourceName - conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json") val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) val df = jsonRDD(rdd) + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + // Save the df as a managed table (by not specifiying the path). + df.saveAsTable("savedJsonTable") + + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str${i}"))) + + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str${i}"))) + + invalidateTable("savedJsonTable") + + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str${i}"))) + + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str${i}"))) + + // Drop table will also delete the data. + sql("DROP TABLE savedJsonTable") + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } + + test("save table") { + val originalDefaultSource = conf.defaultDataSourceName + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val df = jsonRDD(rdd) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + // Save the df as a managed table (by not specifiying the path). df.saveAsTable("savedJsonTable") checkAnswer( sql("SELECT * FROM savedJsonTable"), df.collect()) - createTable("createdJsonTable", catalog.hiveDefaultTableFilePath("savedJsonTable"), false) + // Right now, we cannot append to an existing JSON table. + intercept[RuntimeException] { + df.saveAsTable("savedJsonTable", SaveMode.Append) + } + + // We can overwrite it. + df.saveAsTable("savedJsonTable", SaveMode.Overwrite) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // When the save mode is Ignore, we will do nothing when the table already exists. + df.select("b").saveAsTable("savedJsonTable", SaveMode.Ignore) + assert(df.schema === table("savedJsonTable").schema) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // Drop table will also delete the data. + sql("DROP TABLE savedJsonTable") + intercept[InvalidInputException] { + jsonFile(catalog.hiveDefaultTableFilePath("savedJsonTable")) + } + + // Create an external table by specifying the path. + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.saveAsTable( + "savedJsonTable", + "org.apache.spark.sql.json", + SaveMode.Append, + Map("path" -> tempPath.toString)) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // Data should not be deleted after we drop the table. + sql("DROP TABLE savedJsonTable") + checkAnswer( + jsonFile(tempPath.toString), + df.collect()) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } + + test("create external table") { + val originalDefaultSource = conf.defaultDataSourceName + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val df = jsonRDD(rdd) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.saveAsTable( + "savedJsonTable", + "org.apache.spark.sql.json", + SaveMode.Append, + Map("path" -> tempPath.toString)) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + createExternalTable("createdJsonTable", tempPath.toString) assert(table("createdJsonTable").schema === df.schema) checkAnswer( sql("SELECT * FROM createdJsonTable"), df.collect()) - val message = intercept[RuntimeException] { - createTable("createdJsonTable", filePath.toString, false) + var message = intercept[AnalysisException] { + createExternalTable("createdJsonTable", filePath.toString) }.getMessage assert(message.contains("Table createdJsonTable already exists."), "We should complain that ctasJsonTable already exists") - createTable("createdJsonTable", filePath.toString, true) - // createdJsonTable should be not changed. - assert(table("createdJsonTable").schema === df.schema) + // Data should not be deleted. + sql("DROP TABLE createdJsonTable") checkAnswer( - sql("SELECT * FROM createdJsonTable"), + jsonFile(tempPath.toString), df.collect()) - conf.setConf("spark.sql.default.datasource", originalDefaultSource) + // Try to specify the schema. + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + val schema = StructType(StructField("b", StringType, true) :: Nil) + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map("path" -> tempPath.toString)) + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + sql("SELECT b FROM savedJsonTable").collect()) + + sql("DROP TABLE createdJsonTable") + + message = intercept[RuntimeException] { + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map.empty[String, String]) + }.getMessage + assert( + message.contains("'path' must be specified for json data."), + "We should complain that path is not specified.") + + sql("DROP TABLE savedJsonTable") + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } + + if (HiveShim.version == "0.13.1") { + test("scan a parquet table created through a CTAS statement") { + val originalConvertMetastore = getConf("spark.sql.hive.convertMetastoreParquet", "true") + val originalUseDataSource = getConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + setConf("spark.sql.hive.convertMetastoreParquet", "true") + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + jsonRDD(rdd).registerTempTable("jt") + sql( + """ + |create table test_parquet_ctas STORED AS parquET + |AS select tmp.a from jt tmp where tmp.a < 5 + """.stripMargin) + + checkAnswer( + sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + Row(3) :: Row(4) :: Nil + ) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation2) => // OK + case _ => + fail( + "test_parquet_ctas should be converted to " + + s"${classOf[ParquetRelation2].getCanonicalName}") + } + + // Clenup and reset confs. + sql("DROP TABLE IF EXISTS jt") + sql("DROP TABLE IF EXISTS test_parquet_ctas") + setConf("spark.sql.hive.convertMetastoreParquet", originalConvertMetastore) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalUseDataSource) + } + } + + test("Pre insert nullability check (ArrayType)") { + val df1 = + createDataFrame(Tuple1(Seq(Int.box(1), null.asInstanceOf[Integer])) :: Nil).toDF("a") + val expectedSchema1 = + StructType( + StructField("a", ArrayType(IntegerType, containsNull = true), nullable = true) :: Nil) + assert(df1.schema === expectedSchema1) + df1.saveAsTable("arrayInParquet", "parquet", SaveMode.Overwrite) + + val df2 = + createDataFrame(Tuple1(Seq(2, 3)) :: Nil).toDF("a") + val expectedSchema2 = + StructType( + StructField("a", ArrayType(IntegerType, containsNull = false), nullable = true) :: Nil) + assert(df2.schema === expectedSchema2) + df2.insertInto("arrayInParquet", overwrite = false) + createDataFrame(Tuple1(Seq(4, 5)) :: Nil).toDF("a") + .saveAsTable("arrayInParquet", SaveMode.Append) // This one internally calls df2.insertInto. + createDataFrame(Tuple1(Seq(Int.box(6), null.asInstanceOf[Integer])) :: Nil).toDF("a") + .saveAsTable("arrayInParquet", "parquet", SaveMode.Append) + refreshTable("arrayInParquet") + + checkAnswer( + sql("SELECT a FROM arrayInParquet"), + Row(ArrayBuffer(1, null)) :: + Row(ArrayBuffer(2, 3)) :: + Row(ArrayBuffer(4, 5)) :: + Row(ArrayBuffer(6, null)) :: Nil) + + sql("DROP TABLE arrayInParquet") + } + + test("Pre insert nullability check (MapType)") { + val df1 = + createDataFrame(Tuple1(Map(1 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") + val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = true) + val expectedSchema1 = + StructType( + StructField("a", mapType1, nullable = true) :: Nil) + assert(df1.schema === expectedSchema1) + df1.saveAsTable("mapInParquet", "parquet", SaveMode.Overwrite) + + val df2 = + createDataFrame(Tuple1(Map(2 -> 3)) :: Nil).toDF("a") + val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = false) + val expectedSchema2 = + StructType( + StructField("a", mapType2, nullable = true) :: Nil) + assert(df2.schema === expectedSchema2) + df2.insertInto("mapInParquet", overwrite = false) + createDataFrame(Tuple1(Map(4 -> 5)) :: Nil).toDF("a") + .saveAsTable("mapInParquet", SaveMode.Append) // This one internally calls df2.insertInto. + createDataFrame(Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") + .saveAsTable("mapInParquet", "parquet", SaveMode.Append) + refreshTable("mapInParquet") + + checkAnswer( + sql("SELECT a FROM mapInParquet"), + Row(Map(1 -> null)) :: + Row(Map(2 -> 3)) :: + Row(Map(4 -> 5)) :: + Row(Map(6 -> null)) :: Nil) + + sql("DROP TABLE mapInParquet") + } + + test("SPARK-6024 wide schema support") { + // We will need 80 splits for this schema if the threshold is 4000. + val schema = StructType((1 to 5000).map(i => StructField(s"c_${i}", StringType, true))) + assert( + schema.json.size > conf.schemaStringLengthThreshold, + "To correctly test the fix of SPARK-6024, the value of " + + s"spark.sql.sources.schemaStringLengthThreshold needs to be less than ${schema.json.size}") + // Manually create a metastore data source table. + catalog.createDataSourceTable( + tableName = "wide_schema", + userSpecifiedSchema = Some(schema), + provider = "json", + options = Map("path" -> "just a dummy path"), + isExternal = false) + + invalidateTable("wide_schema") + + val actualSchema = table("wide_schema").schema + assert(schema === actualSchema) + } + + test("SPARK-6655 still support a schema stored in spark.sql.sources.schema") { + val tableName = "spark6655" + val schema = StructType(StructField("int", IntegerType, true) :: Nil) + // Manually create the metadata in metastore. + val tbl = new Table("default", tableName) + tbl.setProperty("spark.sql.sources.provider", "json") + tbl.setProperty("spark.sql.sources.schema", schema.json) + tbl.setProperty("EXTERNAL", "FALSE") + tbl.setTableType(TableType.MANAGED_TABLE) + tbl.setSerdeParam("path", catalog.hiveDefaultTableFilePath(tableName)) + catalog.synchronized { + catalog.client.createTable(tbl) + } + + invalidateTable(tableName) + val actualSchema = table(tableName).schema + assert(schema === actualSchema) + sql(s"drop table $tableName") + } + + + test("insert into a table") { + def createDF(from: Int, to: Int): DataFrame = + createDataFrame((from to to).map(i => Tuple2(i, s"str$i"))).toDF("c1", "c2") + + createDF(0, 9).saveAsTable("insertParquet", "parquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 9).map(i => Row(i, s"str$i"))) + + intercept[AnalysisException] { + createDF(10, 19).saveAsTable("insertParquet", "parquet") + } + + createDF(10, 19).saveAsTable("insertParquet", "parquet", SaveMode.Append) + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 19).map(i => Row(i, s"str$i"))) + + createDF(20, 29).saveAsTable("insertParquet", "parquet", SaveMode.Append) + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + (6 to 24).map(i => Row(i, s"str$i"))) + + intercept[AnalysisException] { + createDF(30, 39).saveAsTable("insertParquet") + } + + createDF(30, 39).saveAsTable("insertParquet", SaveMode.Append) + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + (6 to 34).map(i => Row(i, s"str$i"))) + + createDF(40, 49).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + (6 to 44).map(i => Row(i, s"str$i"))) + + createDF(50, 59).saveAsTable("insertParquet", SaveMode.Overwrite) + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + (52 to 54).map(i => Row(i, s"str$i"))) + createDF(60, 69).saveAsTable("insertParquet", SaveMode.Ignore) + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (50 to 59).map(i => Row(i, s"str$i"))) + + createDF(70, 79).insertInto("insertParquet", overwrite = true) + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (70 to 79).map(i => Row(i, s"str$i"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala similarity index 52% rename from sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index f332cb389f339..d6ddd539d159d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -15,37 +15,19 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.hive -import scala.collection.mutable.MutableList +import org.scalatest.FunSuite -import com.spotify.docker.client._ +import org.apache.spark.SparkConf +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.hive.test.TestHive -/** - * A factory and morgue for DockerClient objects. In the DockerClient we use, - * calling close() closes the desired DockerClient but also renders all other - * DockerClients inoperable. This is inconvenient if we have more than one - * open, such as during tests. - */ -object DockerClientFactory { - var numClients: Int = 0 - val zombies = new MutableList[DockerClient]() - - def get(): DockerClient = { - this.synchronized { - numClients = numClients + 1 - DefaultDockerClient.fromEnv.build() - } - } +class SerializationSuite extends FunSuite { - def close(dc: DockerClient) { - this.synchronized { - numClients = numClients - 1 - zombies += dc - if (numClients == 0) { - zombies.foreach(_.close()) - zombies.clear() - } - } + test("[SPARK-5840] HiveContext should be serializable") { + val hiveContext = new HiveContext(TestHive.sparkContext) + hiveContext.hiveconf + new JavaSerializer(new SparkConf()).newInstance().serialize(hiveContext) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 6f07fd5a879c0..ccd0e5aa51f95 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -120,18 +120,18 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") - intercept[NotImplementedError] { + intercept[UnsupportedOperationException] { analyze("tempTable") } catalog.unregisterTable(Seq("tempTable")) } test("estimates the size of a test MetastoreRelation") { - val rdd = sql("""SELECT * FROM src""") - val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => + val df = sql("""SELECT * FROM src""") + val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } - assert(sizes.size === 1, s"Size wrong for:\n ${rdd.queryExecution}") + assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizes(0).equals(BigInt(5812)), s"expected exact size 5812 for test table 'src', got: ${sizes(0)}") } @@ -145,10 +145,10 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { ct: ClassTag[_]) = { before() - var rdd = sql(query) + var df = sql(query) // Assert src has a size smaller than the threshold. - val sizes = rdd.queryExecution.analyzed.collect { + val sizes = df.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } assert(sizes.size === 2 && sizes(0) <= conf.autoBroadcastJoinThreshold @@ -157,21 +157,21 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. - var bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + var bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.size === 1, - s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + s"actual query plans do not contain broadcast join: ${df.queryExecution}") - checkAnswer(rdd, expectedAnswer) // check correctness of output + checkAnswer(df, expectedAnswer) // check correctness of output TestHive.conf.settings.synchronized { val tmp = conf.autoBroadcastJoinThreshold sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") - rdd = sql(query) - bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + df = sql(query) + bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") - val shj = rdd.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } + val shj = df.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") @@ -199,10 +199,10 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin val answer = Row(86, "val_86") - var rdd = sql(leftSemiJoinQuery) + var df = sql(leftSemiJoinQuery) // Assert src has a size smaller than the threshold. - val sizes = rdd.queryExecution.analyzed.collect { + val sizes = df.queryExecution.analyzed.collect { case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes @@ -213,25 +213,25 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. - var bhj = rdd.queryExecution.sparkPlan.collect { + var bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastLeftSemiJoinHash => j } assert(bhj.size === 1, - s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + s"actual query plans do not contain broadcast join: ${df.queryExecution}") - checkAnswer(rdd, answer) // check correctness of output + checkAnswer(df, answer) // check correctness of output TestHive.conf.settings.synchronized { val tmp = conf.autoBroadcastJoinThreshold sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") - rdd = sql(leftSemiJoinQuery) - bhj = rdd.queryExecution.sparkPlan.collect { + df = sql(leftSemiJoinQuery) + bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastLeftSemiJoinHash => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") - val shj = rdd.queryExecution.sparkPlan.collect { + val shj = df.queryExecution.sparkPlan.collect { case j: LeftSemiJoinHash => j } assert(shj.size === 1, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala new file mode 100644 index 0000000000000..85b6bc93d7122 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +/* Implicits */ + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHive._ + +case class FunctionResult(f1: String, f2: String) + +class UDFSuite extends QueryTest { + test("UDF case insensitive") { + udf.register("random0", () => { Math.random()}) + udf.register("RANDOM1", () => { Math.random()}) + udf.register("strlenScala", (_: String).length + (_:Int)) + assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index f8a957d55d57e..a90bd1e257ade 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -22,8 +22,8 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} import org.apache.spark.Logging +import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} -import org.apache.spark.sql.hive.DescribeCommand import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala new file mode 100644 index 0000000000000..efbef68cd4447 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.hive.test.TestHive._ + +/** + * A set of tests that validates commands can also be queried by like a table + */ +class HiveOperatorQueryableSuite extends QueryTest { + test("SPARK-5324 query result of describe command") { + loadTestTable("src") + + // register a describe command to be a temp table + sql("desc src").registerTempTable("mydesc") + checkAnswer( + sql("desc mydesc"), + Seq( + Row("col_name", "string", "name of the column"), + Row("data_type", "string", "data type of the column"), + Row("comment", "string", "comment of the column"))) + + checkAnswer( + sql("select * from mydesc"), + Seq( + Row("key", "int", null), + Row("value", "string", null))) + + checkAnswer( + sql("select col_name, data_type, comment from mydesc"), + Seq( + Row("key", "int", null), + Row("value", "string", null))) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index c939e6e99d28a..bdb53ddf59c19 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -22,10 +22,12 @@ import org.apache.spark.sql.hive.test.TestHive class HivePlanTest extends QueryTest { import TestHive._ + import TestHive.implicits._ test("udf constant folding") { - val optimized = sql("SELECT cos(null) FROM src").queryExecution.optimizedPlan - val correctAnswer = sql("SELECT cast(null as double) FROM src").queryExecution.optimizedPlan + Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") + val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan + val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan comparePlans(optimized, correctAnswer) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4c53b10ba96e9..c0d21bc9a89da 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -43,6 +43,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault + import org.apache.spark.sql.hive.test.TestHive.implicits._ + override def beforeAll() { TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) @@ -57,10 +59,10 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { Locale.setDefault(originalLocale) } - test("SPARK-4908: concurent hive native commands") { + test("SPARK-4908: concurrent hive native commands") { (1 to 100).par.map { _ => sql("USE default") - sql("SHOW TABLES") + sql("SHOW DATABASES") } } @@ -200,6 +202,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("having no references", "SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1") + createQueryTest("no from clause", + "SELECT 1, +1, -1") + createQueryTest("boolean = number", """ |SELECT @@ -253,8 +258,30 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("Cast Timestamp to Timestamp in UDF", """ - | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) - | FROM src LIMIT 1 + | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) + | FROM src LIMIT 1 + """.stripMargin) + + createQueryTest("Date comparison test 1", + """ + | SELECT + | CAST(CAST('1970-01-01 22:00:00' AS timestamp) AS date) == + | CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) + | FROM src LIMIT 1 + """.stripMargin) + + createQueryTest("Date comparison test 2", + "SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1") + + createQueryTest("Date cast", + """ + | SELECT + | CAST(CAST(0 AS timestamp) AS date), + | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), + | CAST(0 AS timestamp), + | CAST(CAST(0 AS timestamp) AS string), + | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) + | FROM src LIMIT 1 """.stripMargin) createQueryTest("Simple Average", @@ -402,7 +429,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES |('serialization.last.column.takes.rest'='true') FROM src; """.stripMargin.replaceAll("\n", " ")) - + createQueryTest("LIKE", "SELECT * FROM src WHERE value LIKE '%1%'") @@ -440,6 +467,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("sampling") { sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") + sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s") } test("DataFrame toString") { @@ -540,7 +568,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) - testData.registerTempTable("REGisteredTABle") + testData.toDF().registerTempTable("REGisteredTABle") assertResult(Array(Row(2, "str2"))) { sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + @@ -556,8 +584,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-1704: Explain commands as a DataFrame") { sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - val rdd = sql("explain select key, count(value) from src group by key") - assert(isExplanation(rdd)) + val df = sql("explain select key, count(value) from src group by key") + assert(isExplanation(df)) TestHive.reset() } @@ -565,7 +593,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} - TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test") + TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") val results = sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() @@ -583,30 +611,44 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(sql("select key from src having key > 490").collect().size < 100) } + test("SPARK-5383 alias for udfs with multi output columns") { + assert( + sql("select stack(2, key, value, key, value) as (a, b) from src limit 5") + .collect() + .size == 5) + + assert( + sql("select a, b from (select stack(2, key, value, key, value) as (a, b) from src) t limit 5") + .collect() + .size == 5) + } + test("SPARK-5367: resolve star expression in udf") { assert(sql("select concat(*) from src limit 5").collect().size == 5) assert(sql("select array(*) from src limit 5").collect().size == 5) + assert(sql("select concat(key, *) from src limit 5").collect().size == 5) + assert(sql("select array(key, *) from src limit 5").collect().size == 5) } test("Query Hive native command execution result") { - val tableName = "test_native_commands" + val databaseName = "test_native_commands" assertResult(0) { - sql(s"DROP TABLE IF EXISTS $tableName").count() + sql(s"DROP DATABASE IF EXISTS $databaseName").count() } assertResult(0) { - sql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + sql(s"CREATE DATABASE $databaseName").count() } assert( - sql("SHOW TABLES") + sql("SHOW DATABASES") .select('result) .collect() .map(_.getString(0)) - .contains(tableName)) + .contains(databaseName)) - assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) + assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM src GROUP BY key"))) TestHive.reset() } @@ -699,12 +741,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) - testData.registerTempTable("test_describe_commands2") + testData.toDF().registerTempTable("test_describe_commands2") assertResult( Array( - Row("a", "IntegerType", null), - Row("b", "StringType", null)) + Row("a", "int", ""), + Row("b", "string", "")) ) { sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) @@ -818,6 +860,22 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + test("SPARK-5592: get java.net.URISyntaxException when dynamic partitioning") { + sql(""" + |create table sc as select * + |from (select '2011-01-11', '2011-01-11+14:18:26' from src tablesample (1 rows) + |union all + |select '2011-01-11', '2011-01-11+15:18:26' from src tablesample (1 rows) + |union all + |select '2011-01-11', '2011-01-11+16:18:26' from src tablesample (1 rows) ) s + """.stripMargin) + sql("create table sc_part (key string) partitioned by (ts string) stored as rcfile") + sql("set hive.exec.dynamic.partition=true") + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("insert overwrite table sc_part partition(ts) select * from sc") + sql("drop table sc_part") + } + test("Partition spec validation") { sql("DROP TABLE IF EXISTS dp_test") sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") @@ -843,8 +901,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { - sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs") - sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles") + sparkContext.makeRDD(Seq.empty[LogEntry]).toDF().registerTempTable("rawLogs") + sparkContext.makeRDD(Seq.empty[LogFile]).toDF().registerTempTable("logFiles") sql( """ @@ -922,8 +980,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val testVal = "test.val.0" val nonexistentKey = "nonexistent" val KV = "([^=]+)=([^=]*)".r - def collectResults(rdd: DataFrame): Set[(String, String)] = - rdd.collect().map { + def collectResults(df: DataFrame): Set[(String, String)] = + df.collect().map { case Row(key: String, value: String) => key -> value case Row(KV(key, value)) => key -> value }.toSet diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 422e843d2b0d2..f4440e5b7846a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.test.TestHive.{sparkContext, jsonRDD, sql} +import org.apache.spark.sql.hive.test.TestHive.implicits._ case class Nested(a: Int, B: Int) case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) @@ -28,16 +29,24 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) */ class HiveResolutionSuite extends HiveComparisonTest { - case class NestedData(a: Seq[NestedData2], B: NestedData2) - case class NestedData2(a: NestedData3, B: NestedData3) - case class NestedData3(a: Int, B: Int) - test("SPARK-3698: case insensitive test for nested data") { - sparkContext.makeRDD(Seq.empty[NestedData]).registerTempTable("nested") + jsonRDD(sparkContext.makeRDD( + """{"a": [{"a": {"a": 1}}]}""" :: Nil)).registerTempTable("nested") // This should be successfully analyzed sql("SELECT a[0].A.A from nested").queryExecution.analyzed } + test("SPARK-5278: check ambiguous reference to fields") { + jsonRDD(sparkContext.makeRDD( + """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested") + + // there are 2 filed matching field name "b", we should report Ambiguous reference error + val exception = intercept[AnalysisException] { + sql("SELECT a[0].b from nested").queryExecution.analyzed + } + assert(exception.getMessage.contains("Ambiguous reference to fields")) + } + createQueryTest("table.attr", "SELECT src.key FROM src ORDER BY key LIMIT 1") @@ -67,8 +76,8 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection - TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("caseSensitivityTest") + sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + .toDF().registerTempTable("caseSensitivityTest") val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"), @@ -78,18 +87,27 @@ class HiveResolutionSuite extends HiveComparisonTest { ignore("case insensitivity with scala reflection joins") { // Test resolution with Scala Reflection - TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("caseSensitivityTest") + sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + .toDF().registerTempTable("caseSensitivityTest") sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { - TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("nestedRepeatedTest") + sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + .toDF().registerTempTable("nestedRepeatedTest") assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } + createQueryTest("test ambiguousReferences resolved as hive", + """ + |CREATE TABLE t1(x INT); + |CREATE TABLE t2(a STRUCT, k INT); + |INSERT OVERWRITE TABLE t1 SELECT 1 FROM src LIMIT 1; + |INSERT OVERWRITE TABLE t2 SELECT named_struct("x",1),1 FROM src LIMIT 1; + |SELECT a.x FROM t1 a JOIN t2 b ON a.x = b.k; + """.stripMargin) + /** * Negative examples. Currently only left here for documentation purposes. * TODO(marmbrus): Test that catalyst fails on these queries. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 8fb5e050a237a..ab53c6309e089 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index dd0df1a9f6320..d7c5d1a25a82b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -22,7 +22,7 @@ import java.util import java.util.Properties import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} @@ -47,7 +47,9 @@ case class ListStringCaseClass(l: Seq[String]) * A test suite for Hive custom UDFs. */ class HiveUdfSuite extends QueryTest { - import TestHive._ + + import TestHive.{udf, sql} + import TestHive.implicits._ test("spark sql udf test that returns a struct") { udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -60,7 +62,7 @@ class HiveUdfSuite extends QueryTest { | getStruct(1).f5 FROM src LIMIT 1 """.stripMargin).head() === Row(1, 2, 3, 4, 5)) } - + test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { checkAnswer( sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), @@ -91,10 +93,19 @@ class HiveUdfSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } + test("SPARK-6409 UDAFAverage test") { + sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer( + sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), + Seq(Row(1.0, 260.182))) + sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") + TestHive.reset() + } + test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), sql("SELECT max(key) FROM src").collect().toSeq) - + checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) } @@ -102,14 +113,14 @@ class HiveUdfSuite extends QueryTest { test("Generic UDAF aggregates") { checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) - + checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) } - + test("UDFIntegerToString") { val testData = TestHive.sparkContext.parallelize( - IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil) + IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") @@ -125,7 +136,7 @@ class HiveUdfSuite extends QueryTest { val testData = TestHive.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: - ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil) + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() testData.registerTempTable("listListIntTable") sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") @@ -140,7 +151,7 @@ class HiveUdfSuite extends QueryTest { test("UDFListString") { val testData = TestHive.sparkContext.parallelize( ListStringCaseClass(Seq("a", "b", "c")) :: - ListStringCaseClass(Seq("d", "e")) :: Nil) + ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() testData.registerTempTable("listStringTable") sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") @@ -154,7 +165,7 @@ class HiveUdfSuite extends QueryTest { test("UDFStringString") { val testData = TestHive.sparkContext.parallelize( - StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil) + StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") @@ -171,7 +182,7 @@ class HiveUdfSuite extends QueryTest { ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: - Nil) + Nil).toDF() testData.registerTempTable("TwoListTable") sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 4efe0c5e0cd44..e177f290d501f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,22 +17,138 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.QueryTest - -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.hive.{MetastoreRelation, HiveShim} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) case class Nested3(f3: Int) +case class NestedArray2(b: Seq[Int]) +case class NestedArray1(a: NestedArray2) + +case class Order( + id: Int, + make: String, + `type`: String, + price: Int, + pdate: String, + customer: String, + city: String, + state: String, + month: Int) + /** * A collection of hive query tests where we generate the answers ourselves instead of depending on * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ class SQLQuerySuite extends QueryTest { + + test("SPARK-6851: Self-joined converted parquet tables") { + val orders = Seq( + Order(1, "Atlas", "MTB", 234, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(3, "Swift", "MTB", 285, "2015-01-17", "John S", "Redwood City", "CA", 20151), + Order(4, "Atlas", "Hybrid", 303, "2015-01-23", "Jones S", "San Mateo", "CA", 20151), + Order(7, "Next", "MTB", 356, "2015-01-04", "Jane D", "Daly City", "CA", 20151), + Order(10, "Next", "YFlikr", 187, "2015-01-09", "John D", "Fremont", "CA", 20151), + Order(11, "Swift", "YFlikr", 187, "2015-01-23", "John D", "Hayward", "CA", 20151), + Order(2, "Next", "Hybrid", 324, "2015-02-03", "Jane D", "Daly City", "CA", 20152), + Order(5, "Next", "Street", 187, "2015-02-08", "John D", "Fremont", "CA", 20152), + Order(6, "Atlas", "Street", 154, "2015-02-09", "John D", "Pacifica", "CA", 20152), + Order(8, "Swift", "Hybrid", 485, "2015-02-19", "John S", "Redwood City", "CA", 20152), + Order(9, "Atlas", "Split", 303, "2015-02-28", "Jones S", "San Mateo", "CA", 20152)) + + val orderUpdates = Seq( + Order(1, "Atlas", "MTB", 434, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(11, "Swift", "YFlikr", 137, "2015-01-23", "John D", "Hayward", "CA", 20151)) + + orders.toDF.registerTempTable("orders1") + orderUpdates.toDF.registerTempTable("orderupdates1") + + sql( + """CREATE TABLE orders( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql( + """CREATE TABLE orderupdates( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") + sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") + + checkAnswer( + sql( + """ + |select orders.state, orders.month + |from orders + |join ( + | select distinct orders.state,orders.month + | from orders + | join orderupdates + | on orderupdates.id = orders.id) ao + | on ao.state = orders.state and ao.month = orders.month + """.stripMargin), + (1 to 6).map(_ => Row("CA", 20151))) + } + + test("SPARK-5371: union with null and sum") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + + val query = sql( + """ + |SELECT + | MIN(c1), + | MIN(c2) + |FROM ( + | SELECT + | SUM(c1) c1, + | NULL c2 + | FROM table1 + | UNION ALL + | SELECT + | NULL c1, + | SUM(c2) c2 + | FROM table1 + |) a + """.stripMargin) + checkAnswer(query, Row(1, 1) :: Nil) + } + + test("explode nested Field") { + Seq(NestedArray1(NestedArray2(Seq(1, 2, 3)))).toDF.registerTempTable("nestedArray") + checkAnswer( + sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), + Row(1) :: Row(2) :: Row(3) :: Nil) + } + test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") { checkAnswer( sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"), @@ -40,6 +156,73 @@ class SQLQuerySuite extends QueryTest { ) } + test("CTAS without serde") { + def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { + val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) + relation match { + case LogicalRelation(r: ParquetRelation2) => + if (!isDataSourceParquet) { + fail( + s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + + s"${ParquetRelation2.getClass.getCanonicalName}.") + } + + case r: MetastoreRelation => + if (isDataSourceParquet) { + fail( + s"${ParquetRelation2.getClass.getCanonicalName} is expected, but found " + + s"${classOf[MetastoreRelation].getCanonicalName}.") + } + } + } + + val originalConf = getConf("spark.sql.hive.convertCTAS", "false") + + setConf("spark.sql.hive.convertCTAS", "true") + + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + var message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("Table ctas1 already exists")) + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + // Specifying database name for query can be converted to data source write path + // is not allowed right now. + message = intercept[AnalysisException] { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert( + message.contains("Cannot specify database name in a CTAS statement"), + "When spark.sql.hive.convertCTAS is true, we should not allow " + + "database name specified.") + + sql("CREATE TABLE ctas1 stored as textfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql( + "CREATE TABLE ctas1 stored as sequencefile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + setConf("spark.sql.hive.convertCTAS", originalConf) + } + test("CTAS with serde") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() sql( @@ -102,6 +285,37 @@ class SQLQuerySuite extends QueryTest { "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe", "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22","MANAGED_TABLE" ) + + if (HiveShim.version =="0.13.1") { + val origUseParquetDataSource = conf.parquetUseDataSourceApi + try { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + + checkExistence(sql("DESC EXTENDED ctas5"), true, + "name:key", "type:string", "name:value", "ctas5", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "MANAGED_TABLE" + ) + + val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") + // use the Hive SerDe for parquet tables + sql("set spark.sql.hive.convertMetastoreParquet = false") + checkAnswer( + sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql(s"set spark.sql.hive.convertMetastoreParquet = $default") + } finally { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) + } + } } test("command substitution") { @@ -141,7 +355,8 @@ class SQLQuerySuite extends QueryTest { } test("double nested data") { - sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested") + sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil) + .toDF().registerTempTable("nested") checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), Row(1)) @@ -151,7 +366,7 @@ class SQLQuerySuite extends QueryTest { sql("SELECT * FROM test_ctas_1234"), sql("SELECT * FROM nested").collect().toSeq) - intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] { + intercept[AnalysisException] { sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect() } } @@ -159,12 +374,12 @@ class SQLQuerySuite extends QueryTest { test("test CTAS") { checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) checkAnswer( - sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), + sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) } test("SPARK-4825 save join to table") { - val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() sql("CREATE TABLE test1 (key INT, value STRING)") testData.insertInto("test1") sql("CREATE TABLE test2 (key INT, value STRING)") @@ -244,7 +459,7 @@ class SQLQuerySuite extends QueryTest { val rowRdd = sparkContext.parallelize(row :: Nil) - applySchema(rowRdd, schema).registerTempTable("testTable") + TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable") sql( """CREATE TABLE nullValuesInInnerComplexTypes @@ -282,4 +497,45 @@ class SQLQuerySuite extends QueryTest { dropTempTable("data") } + + test("logical.Project should not be resolved if it contains aggregates or generators") { + // This test is used to test the fix of SPARK-5875. + // The original issue was that Project's resolved will be true when it contains + // AggregateExpressions or Generators. However, in this case, the Project + // is not in a valid state (cannot be executed). Because of this bug, the analysis rule of + // PreInsertionCasts will actually start to work before ImplicitGenerate and then + // generates an invalid query plan. + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i+1}]}""")) + jsonRDD(rdd).registerTempTable("data") + val originalConf = getConf("spark.sql.hive.convertCTAS", "false") + setConf("spark.sql.hive.convertCTAS", "false") + + sql("CREATE TABLE explodeTest (key bigInt)") + table("explodeTest").queryExecution.analyzed match { + case metastoreRelation: MetastoreRelation => // OK + case _ => + fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") + } + + sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") + checkAnswer( + sql("SELECT key from explodeTest"), + (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) + ) + + sql("DROP TABLE explodeTest") + dropTempTable("data") + setConf("spark.sql.hive.convertCTAS", originalConf) + } + + test("sanity test for SPARK-6618") { + (1 to 100).par.map { i => + val tableName = s"SPARK_6618_table_$i" + sql(s"CREATE TABLE $tableName (col1 string)") + catalog.lookupRelation(Seq(tableName)) + table(tableName) + tables() + sql(s"DROP TABLE $tableName") + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala new file mode 100644 index 0000000000000..dccac4d128bfc --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -0,0 +1,902 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{QueryTest, SQLConf, SaveMode} +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} +import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.json.JSONRelation +import org.apache.spark.sql.sources.{InsertIntoDataSource, LogicalRelation} +import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.types._ + +// The data where the partitioning key exists only in the directory structure. +case class ParquetData(intField: Int, stringField: String) +// The data that also includes the partitioning key +case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) + +case class StructContainer(intStructField :Int, stringStructField: String) + +case class ParquetDataWithComplexTypes( + intField: Int, + stringField: String, + structField: StructContainer, + arrayField: Seq[Int]) + +case class ParquetDataWithKeyAndComplexTypes( + p: Int, + intField: Int, + stringField: String, + structField: StructContainer, + arrayField: Seq[Int]) + +/** + * A suite to test the automatic conversion of metastore tables with parquet data to use the + * built in parquet support. + */ +class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { + override def beforeAll(): Unit = { + super.beforeAll() + + sql(s""" + create external table partitioned_parquet + ( + intField INT, + stringField STRING + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${partitionedTableDir.getCanonicalPath}' + """) + + sql(s""" + create external table partitioned_parquet_with_key + ( + intField INT, + stringField STRING + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${partitionedTableDirWithKey.getCanonicalPath}' + """) + + sql(s""" + create external table normal_parquet + ( + intField INT, + stringField STRING + ) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${new File(normalTableDir, "normal").getCanonicalPath}' + """) + + sql(s""" + CREATE EXTERNAL TABLE partitioned_parquet_with_complextypes + ( + intField INT, + stringField STRING, + structField STRUCT, + arrayField ARRAY + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + LOCATION '${partitionedTableDirWithComplexTypes.getCanonicalPath}' + """) + + sql(s""" + CREATE EXTERNAL TABLE partitioned_parquet_with_key_and_complextypes + ( + intField INT, + stringField STRING, + structField STRUCT, + arrayField ARRAY + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + LOCATION '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' + """) + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + } + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)") + } + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet_with_key_and_complextypes ADD PARTITION (p=$p)") + } + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD PARTITION (p=$p)") + } + + val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + jsonRDD(rdd1).registerTempTable("jt") + val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) + jsonRDD(rdd2).registerTempTable("jt_array") + + setConf("spark.sql.hive.convertMetastoreParquet", "true") + } + + override def afterAll(): Unit = { + sql("DROP TABLE partitioned_parquet") + sql("DROP TABLE partitioned_parquet_with_key") + sql("DROP TABLE partitioned_parquet_with_complextypes") + sql("DROP TABLE partitioned_parquet_with_key_and_complextypes") + sql("DROP TABLE normal_parquet") + sql("DROP TABLE IF EXISTS jt") + sql("DROP TABLE IF EXISTS jt_array") + setConf("spark.sql.hive.convertMetastoreParquet", "false") + } + + test(s"conversion is working") { + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: HiveTableScan => true + }.isEmpty) + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: ParquetTableScan => true + case _: PhysicalRDD => true + }.nonEmpty) + } +} + +class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + + sql( + """ + |create table test_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override def afterAll(): Unit = { + super.afterAll() + sql("DROP TABLE IF EXISTS test_parquet") + + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } + + test("scan an empty parquet table") { + checkAnswer(sql("SELECT count(*) FROM test_parquet"), Row(0)) + } + + test("scan an empty parquet table with upper case") { + checkAnswer(sql("SELECT count(INTFIELD) FROM TEST_parquet"), Row(0)) + } + + test("insert into an empty parquet table") { + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // Insert into am empty table. + sql("insert into table test_insert_parquet select a, b from jt where jt.a > 5") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField < 8"), + Row(6, "str6") :: Row(7, "str7") :: Nil + ) + // Insert overwrite. + sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), + Row(3, "str3") :: Row(4, "str4") :: Nil + ) + sql("DROP TABLE IF EXISTS test_insert_parquet") + + // Create it again. + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + // Insert overwrite an empty table. + sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), + Row(3, "str3") :: Row(4, "str4") :: Nil + ) + // Insert into the table. + sql("insert into table test_insert_parquet select a, b from jt") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet"), + (1 to 10).map(i => Row(i, s"str$i")) ++ (1 to 4).map(i => Row(i, s"str$i")) + ) + sql("DROP TABLE IF EXISTS test_insert_parquet") + } + + test("scan a parquet table created through a CTAS statement") { + sql( + """ + |create table test_parquet_ctas ROW FORMAT + |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |AS select * from jt + """.stripMargin) + + checkAnswer( + sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), + Seq(Row(1, "str1")) + ) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation2) => // OK + case _ => + fail( + s"test_parquet_ctas should be converted to ${classOf[ParquetRelation2].getCanonicalName}") + } + + sql("DROP TABLE IF EXISTS test_parquet_ctas") + } + + test("MetastoreRelation in InsertIntoTable will be converted") { + sql( + """ + |create table test_insert_parquet + |( + | intField INT + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") + df.queryExecution.executedPlan match { + case ExecutedCommand( + InsertIntoDataSource( + LogicalRelation(r: ParquetRelation2), query, overwrite)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation2].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"However, found a ${o.toString} ") + } + + checkAnswer( + sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), + sql("SELECT a FROM jt WHERE jt.a > 5").collect() + ) + + sql("DROP TABLE IF EXISTS test_insert_parquet") + } + + test("MetastoreRelation in InsertIntoHiveTable will be converted") { + sql( + """ + |create table test_insert_parquet + |( + | int_array array + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") + df.queryExecution.executedPlan match { + case ExecutedCommand( + InsertIntoDataSource( + LogicalRelation(r: ParquetRelation2), query, overwrite)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation2].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"However, found a ${o.toString} ") + } + + checkAnswer( + sql("SELECT int_array FROM test_insert_parquet"), + sql("SELECT a FROM jt_array").collect() + ) + + sql("DROP TABLE IF EXISTS test_insert_parquet") + } + + test("SPARK-6450 regression test") { + sql( + """CREATE TABLE IF NOT EXISTS ms_convert (key INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // This shouldn't throw AnalysisException + val analyzed = sql( + """SELECT key FROM ms_convert + |UNION ALL + |SELECT key FROM ms_convert + """.stripMargin).queryExecution.analyzed + + assertResult(2) { + analyzed.collect { + case r @ LogicalRelation(_: ParquetRelation2) => r + }.size + } + + sql("DROP TABLE ms_convert") + } + + test("Caching converted data source Parquet Relations") { + def checkCached(tableIdentifer: catalog.QualifiedTableName): Unit = { + // Converted test_parquet should be cached. + catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) match { + case null => fail("Converted test_parquet should be cached in the cache.") + case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK + case other => + fail( + "The cached test_parquet should be a Parquet Relation. " + + s"However, $other is returned form the cache.") + } + } + + sql("DROP TABLE IF EXISTS test_insert_parquet") + sql("DROP TABLE IF EXISTS test_parquet_partitioned_cache_test") + + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + var tableIdentifer = catalog.QualifiedTableName("default", "test_insert_parquet") + + // First, make sure the converted test_parquet is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + // Table lookup will make the table cached. + table("test_insert_parquet") + checkCached(tableIdentifer) + // For insert into non-partitioned table, we will do the conversion, + // so the converted test_insert_parquet should be cached. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_insert_parquet + |select a, b from jt + """.stripMargin) + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select * from test_insert_parquet"), + sql("select a, b from jt").collect()) + // Invalidate the cache. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + // Create a partitioned table. + sql( + """ + |create table test_parquet_partitioned_cache_test + |( + | intField INT, + | stringField STRING + |) + |PARTITIONED BY (date string) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + tableIdentifer = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-01') + |select a, b from jt + """.stripMargin) + // Right now, insert into a partitioned Parquet is not supported in data source Parquet. + // So, we expect it is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-02') + |select a, b from jt + """.stripMargin) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + // Make sure we can cache the partitioned table. + table("test_parquet_partitioned_cache_test") + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), + sql( + """ + |select b, '2015-04-01', a FROM jt + |UNION ALL + |select b, '2015-04-02', a FROM jt + """.stripMargin).collect()) + + invalidateTable("test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + sql("DROP TABLE test_insert_parquet") + sql("DROP TABLE test_parquet_partitioned_cache_test") + } +} + +class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } + + test("MetastoreRelation in InsertIntoTable will not be converted") { + sql( + """ + |create table test_insert_parquet + |( + | intField INT + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") + df.queryExecution.executedPlan match { + case insert: execution.InsertIntoHiveTable => // OK + case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " + + s"However, found ${o.toString}.") + } + + checkAnswer( + sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), + sql("SELECT a FROM jt WHERE jt.a > 5").collect() + ) + + sql("DROP TABLE IF EXISTS test_insert_parquet") + } + + // TODO: enable it after the fix of SPARK-5950. + ignore("MetastoreRelation in InsertIntoHiveTable will not be converted") { + sql( + """ + |create table test_insert_parquet + |( + | int_array array + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") + df.queryExecution.executedPlan match { + case insert: execution.InsertIntoHiveTable => // OK + case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " + + s"However, found ${o.toString}.") + } + + checkAnswer( + sql("SELECT int_array FROM test_insert_parquet"), + sql("SELECT a FROM jt_array").collect() + ) + + sql("DROP TABLE IF EXISTS test_insert_parquet") + } +} + +/** + * A suite of tests for the Parquet support through the data sources API. + */ +class ParquetSourceSuiteBase extends ParquetPartitioningTest { + override def beforeAll(): Unit = { + super.beforeAll() + + sql( s""" + create temporary table partitioned_parquet + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDir.getCanonicalPath}' + ) + """) + + sql( s""" + create temporary table partitioned_parquet_with_key + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDirWithKey.getCanonicalPath}' + ) + """) + + sql( s""" + create temporary table normal_parquet + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${new File(partitionedTableDir, "p=1").getCanonicalPath}' + ) + """) + + sql( s""" + CREATE TEMPORARY TABLE partitioned_parquet_with_key_and_complextypes + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' + ) + """) + + sql( s""" + CREATE TEMPORARY TABLE partitioned_parquet_with_complextypes + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDirWithComplexTypes.getCanonicalPath}' + ) + """) + } + + test("SPARK-6016 make sure to use the latest footers") { + sql("drop table if exists spark_6016_fix") + + // Create a DataFrame with two partitions. So, the created table will have two parquet files. + val df1 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) + df1.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite) + checkAnswer( + sql("select * from spark_6016_fix"), + (1 to 10).map(i => Row(i)) + ) + + // Create a DataFrame with four partitions. So, the created table will have four parquet files. + val df2 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) + df2.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite) + // For the bug of SPARK-6016, we are caching two outdated footers for df1. Then, + // since the new table has four parquet files, we are trying to read new footers from two files + // and then merge metadata in footers of these four (two outdated ones and two latest one), + // which will cause an error. + checkAnswer( + sql("select * from spark_6016_fix"), + (1 to 10).map(i => Row(i)) + ) + + sql("drop table spark_6016_fix") + } +} + +class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } + + test("values in arrays and maps stored in parquet are always nullable") { + val df = createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") + val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = false) + val arrayType1 = ArrayType(IntegerType, containsNull = false) + val expectedSchema1 = + StructType( + StructField("m", mapType1, nullable = true) :: + StructField("a", arrayType1, nullable = true) :: Nil) + assert(df.schema === expectedSchema1) + + df.saveAsTable("alwaysNullable", "parquet") + + val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) + val arrayType2 = ArrayType(IntegerType, containsNull = true) + val expectedSchema2 = + StructType( + StructField("m", mapType2, nullable = true) :: + StructField("a", arrayType2, nullable = true) :: Nil) + + assert(table("alwaysNullable").schema === expectedSchema2) + + checkAnswer( + sql("SELECT m, a FROM alwaysNullable"), + Row(Map(2 -> 3), Seq(4, 5, 6))) + + sql("DROP TABLE alwaysNullable") + } +} + +class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +/** + * A collection of tests for parquet data with various forms of partitioning. + */ +abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll { + var partitionedTableDir: File = null + var normalTableDir: File = null + var partitionedTableDirWithKey: File = null + var partitionedTableDirWithComplexTypes: File = null + var partitionedTableDirWithKeyAndComplexTypes: File = null + + override def beforeAll(): Unit = { + partitionedTableDir = File.createTempFile("parquettests", "sparksql") + partitionedTableDir.delete() + partitionedTableDir.mkdir() + + normalTableDir = File.createTempFile("parquettests", "sparksql") + normalTableDir.delete() + normalTableDir.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDir, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-$p")) + .toDF() + .saveAsParquetFile(partDir.getCanonicalPath) + } + + sparkContext + .makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-1")) + .toDF() + .saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath) + + partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql") + partitionedTableDirWithKey.delete() + partitionedTableDirWithKey.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithKey, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetDataWithKey(p, i, s"part-$p")) + .toDF() + .saveAsParquetFile(partDir.getCanonicalPath) + } + + partitionedTableDirWithKeyAndComplexTypes = File.createTempFile("parquettests", "sparksql") + partitionedTableDirWithKeyAndComplexTypes.delete() + partitionedTableDirWithKeyAndComplexTypes.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithKeyAndComplexTypes, s"p=$p") + sparkContext.makeRDD(1 to 10).map { i => + ParquetDataWithKeyAndComplexTypes( + p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) + }.toDF().saveAsParquetFile(partDir.getCanonicalPath) + } + + partitionedTableDirWithComplexTypes = File.createTempFile("parquettests", "sparksql") + partitionedTableDirWithComplexTypes.delete() + partitionedTableDirWithComplexTypes.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p") + sparkContext.makeRDD(1 to 10).map { i => + ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) + }.toDF().saveAsParquetFile(partDir.getCanonicalPath) + } + } + + override protected def afterAll(): Unit = { + partitionedTableDir.delete() + normalTableDir.delete() + partitionedTableDirWithKey.delete() + partitionedTableDirWithComplexTypes.delete() + partitionedTableDirWithKeyAndComplexTypes.delete() + } + + Seq( + "partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes").foreach { table => + + test(s"ordering of the partitioning columns $table") { + checkAnswer( + sql(s"SELECT p, stringField FROM $table WHERE p = 1"), + Seq.fill(10)(Row(1, "part-1")) + ) + + checkAnswer( + sql(s"SELECT stringField, p FROM $table WHERE p = 1"), + Seq.fill(10)(Row("part-1", 1)) + ) + } + + test(s"project the partitioning column $table") { + checkAnswer( + sql(s"SELECT p, count(*) FROM $table group by p"), + Row(1, 10) :: + Row(2, 10) :: + Row(3, 10) :: + Row(4, 10) :: + Row(5, 10) :: + Row(6, 10) :: + Row(7, 10) :: + Row(8, 10) :: + Row(9, 10) :: + Row(10, 10) :: Nil + ) + } + + test(s"project partitioning and non-partitioning columns $table") { + checkAnswer( + sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), + Row("part-1", 1, 10) :: + Row("part-2", 2, 10) :: + Row("part-3", 3, 10) :: + Row("part-4", 4, 10) :: + Row("part-5", 5, 10) :: + Row("part-6", 6, 10) :: + Row("part-7", 7, 10) :: + Row("part-8", 8, 10) :: + Row("part-9", 9, 10) :: + Row("part-10", 10, 10) :: Nil + ) + } + + test(s"simple count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table"), + Row(100)) + } + + test(s"pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), + Row(10)) + } + + test(s"non-existent partition $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), + Row(0)) + } + + test(s"multi-partition pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), + Row(30)) + } + + test(s"non-partition predicates $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), + Row(30)) + } + + test(s"sum $table") { + checkAnswer( + sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), + Row(1 + 2 + 3)) + } + + test(s"hive udfs $table") { + checkAnswer( + sql(s"SELECT concat(stringField, stringField) FROM $table"), + sql(s"SELECT stringField FROM $table").map { + case Row(s: String) => Row(s + s) + }.collect().toSeq) + } + } + + Seq( + "partitioned_parquet_with_key_and_complextypes", + "partitioned_parquet_with_complextypes").foreach { table => + + test(s"SPARK-5775 read struct from $table") { + checkAnswer( + sql(s"SELECT p, structField.intStructField, structField.stringStructField FROM $table WHERE p = 1"), + (1 to 10).map(i => Row(1, i, f"${i}_string"))) + } + + // Re-enable this after SPARK-5508 is fixed + ignore(s"SPARK-5775 read array from $table") { + checkAnswer( + sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), + (1 to 10).map(i => Row(1 to i, 1))) + } + } + + + test("non-part select(*)") { + checkAnswer( + sql("SELECT COUNT(*) FROM normal_parquet"), + Row(10)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala deleted file mode 100644 index 581f666399492..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.hive.test.TestHive - -case class Cases(lower: String, UPPER: String) - -class HiveParquetSuite extends QueryTest with ParquetTest { - val sqlContext = TestHive - - import sqlContext._ - - test("Case insensitive attribute names") { - withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { - val expected = (1 to 4).map(i => Row(i.toString)) - checkAnswer(sql("SELECT upper FROM cases"), expected) - checkAnswer(sql("SELECT LOWER FROM cases"), expected) - } - } - - test("SELECT on Parquet table") { - val data = (1 to 4).map(i => (i, s"val_$i")) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) - } - } - - test("Simple column projection + filter on Parquet table") { - withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { - checkAnswer( - sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), - Seq(Row(true, "val_2"), Row(true, "val_4"))) - } - } - - test("Converting Hive to Parquet Table via saveAsParquetFile") { - withTempPath { dir => - sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).registerTempTable("p") - withTempTable("p") { - checkAnswer( - sql("SELECT * FROM src ORDER BY key"), - sql("SELECT * from p ORDER BY key").collect().toSeq) - } - } - } - - - test("INSERT OVERWRITE TABLE Parquet table") { - withParquetTable((1 to 4).map(i => (i, s"val_$i")), "t") { - withTempPath { file => - sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) - parquetFile(file.getCanonicalPath).registerTempTable("p") - withTempTable("p") { - // let's do three overwrites for good measure - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) - } - } - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala deleted file mode 100644 index 79fd99d9f89ff..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ /dev/null @@ -1,271 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.File - -import org.apache.spark.sql.catalyst.expressions.Row -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.execution.HiveTableScan -import org.apache.spark.sql.hive.test.TestHive._ - -// The data where the partitioning key exists only in the directory structure. -case class ParquetData(intField: Int, stringField: String) -// The data that also includes the partitioning key -case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) - - -/** - * A suite to test the automatic conversion of metastore tables with parquet data to use the - * built in parquet support. - */ -class ParquetMetastoreSuite extends ParquetPartitioningTest { - override def beforeAll(): Unit = { - super.beforeAll() - - sql(s""" - create external table partitioned_parquet - ( - intField INT, - stringField STRING - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${partitionedTableDir.getCanonicalPath}' - """) - - sql(s""" - create external table partitioned_parquet_with_key - ( - intField INT, - stringField STRING - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${partitionedTableDirWithKey.getCanonicalPath}' - """) - - sql(s""" - create external table normal_parquet - ( - intField INT, - stringField STRING - ) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' - """) - - (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") - } - - (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)") - } - - setConf("spark.sql.hive.convertMetastoreParquet", "true") - } - - override def afterAll(): Unit = { - setConf("spark.sql.hive.convertMetastoreParquet", "false") - } - - test("conversion is working") { - assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: HiveTableScan => true - }.isEmpty) - assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: ParquetTableScan => true - }.nonEmpty) - } -} - -/** - * A suite of tests for the Parquet support through the data sources API. - */ -class ParquetSourceSuite extends ParquetPartitioningTest { - override def beforeAll(): Unit = { - super.beforeAll() - - sql( s""" - create temporary table partitioned_parquet - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${partitionedTableDir.getCanonicalPath}' - ) - """) - - sql( s""" - create temporary table partitioned_parquet_with_key - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${partitionedTableDirWithKey.getCanonicalPath}' - ) - """) - - sql( s""" - create temporary table normal_parquet - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${new File(partitionedTableDir, "p=1").getCanonicalPath}' - ) - """) - } -} - -/** - * A collection of tests for parquet data with various forms of partitioning. - */ -abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll { - var partitionedTableDir: File = null - var partitionedTableDirWithKey: File = null - - override def beforeAll(): Unit = { - partitionedTableDir = File.createTempFile("parquettests", "sparksql") - partitionedTableDir.delete() - partitionedTableDir.mkdir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDir, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetData(i, s"part-$p")) - .saveAsParquetFile(partDir.getCanonicalPath) - } - - partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql") - partitionedTableDirWithKey.delete() - partitionedTableDirWithKey.mkdir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDirWithKey, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetDataWithKey(p, i, s"part-$p")) - .saveAsParquetFile(partDir.getCanonicalPath) - } - } - - Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => - test(s"ordering of the partitioning columns $table") { - checkAnswer( - sql(s"SELECT p, stringField FROM $table WHERE p = 1"), - Seq.fill(10)(Row(1, "part-1")) - ) - - checkAnswer( - sql(s"SELECT stringField, p FROM $table WHERE p = 1"), - Seq.fill(10)(Row("part-1", 1)) - ) - } - - test(s"project the partitioning column $table") { - checkAnswer( - sql(s"SELECT p, count(*) FROM $table group by p"), - Row(1, 10) :: - Row(2, 10) :: - Row(3, 10) :: - Row(4, 10) :: - Row(5, 10) :: - Row(6, 10) :: - Row(7, 10) :: - Row(8, 10) :: - Row(9, 10) :: - Row(10, 10) :: Nil - ) - } - - test(s"project partitioning and non-partitioning columns $table") { - checkAnswer( - sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), - Row("part-1", 1, 10) :: - Row("part-2", 2, 10) :: - Row("part-3", 3, 10) :: - Row("part-4", 4, 10) :: - Row("part-5", 5, 10) :: - Row("part-6", 6, 10) :: - Row("part-7", 7, 10) :: - Row("part-8", 8, 10) :: - Row("part-9", 9, 10) :: - Row("part-10", 10, 10) :: Nil - ) - } - - test(s"simple count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table"), - Row(100)) - } - - test(s"pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), - Row(10)) - } - - test(s"non-existant partition $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), - Row(0)) - } - - test(s"multi-partition pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), - Row(30)) - } - - test(s"non-partition predicates $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), - Row(30)) - } - - test(s"sum $table") { - checkAnswer( - sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), - Row(1 + 2 + 3)) - } - - test(s"hive udfs $table") { - checkAnswer( - sql(s"SELECT concat(stringField, stringField) FROM $table"), - sql(s"SELECT stringField FROM $table").map { - case Row(s: String) => Row(s + s) - }.collect().toSeq) - } - } - - test("non-part select(*)") { - checkAnswer( - sql("SELECT COUNT(*) FROM normal_parquet"), - Row(10)) - } -} diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 254919e8f6fdc..30646ddbc29d8 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -43,7 +43,9 @@ import org.apache.hadoop.mapred.InputFormat import org.apache.spark.sql.types.{Decimal, DecimalType} -case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable { +private[hive] case class HiveFunctionWrapper(functionClassName: String) + extends java.io.Serializable { + // for Serialization def this() = this(null) @@ -160,7 +162,7 @@ private[hive] object HiveShim { if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) def getTimestampWritable(value: Any): hiveIo.TimestampWritable = if (value == null) { @@ -245,8 +247,13 @@ private[hive] object HiveShim { def prepareWritable(w: Writable): Writable = { w } + + def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = {} } -class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean) +private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) extends FileSinkDesc(dir, tableInfo, compressed) { } diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 45ca59ae56a38..f9fcbdae15745 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition} import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory +import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} @@ -55,7 +56,9 @@ import org.apache.spark.sql.types.{Decimal, DecimalType} * * @param functionClassName UDF class name */ -case class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable { +private[hive] case class HiveFunctionWrapper(var functionClassName: String) + extends java.io.Externalizable { + // for Serialization def this() = this(null) @@ -263,7 +266,7 @@ private[hive] object HiveShim { } def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) def getTimestampWritable(value: Any): hiveIo.TimestampWritable = if (value == null) { @@ -410,13 +413,22 @@ private[hive] object HiveShim { } w } + + def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = { + if (crtTbl != null && crtTbl.getNullFormat() != null) { + tbl.setSerdeParam(serdeConstants.SERIALIZATION_NULL_FORMAT, crtTbl.getNullFormat()) + } + } } /* * Bug introduced in hive-0.13. FileSinkDesc is serilizable, but its member path is not. * Fix it through wrapper. */ -class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean) +private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) extends Serializable with Logging { var compressCodec: String = _ var compressType: String = _ diff --git a/streaming/pom.xml b/streaming/pom.xml index 1e92ba686a57d..6111e11695ac2 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../pom.xml @@ -82,6 +82,11 @@ junit test + + org.seleniumhq.selenium + selenium-java + test + com.novocode junit-interface diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b780282bdac37..832ce782b4dad 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -43,10 +43,13 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConfPairs = ssc.conf.getAll - def sparkConf = { - new SparkConf(false).setAll(sparkConfPairs) + def createSparkConf(): SparkConf = { + val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.port") + val newMasterOption = new SparkConf(loadDefaults = true).getOption("spark.master") + newMasterOption.foreach { newMaster => newSparkConf.setMaster(newMaster) } + newSparkConf } def validate() { @@ -116,7 +119,10 @@ class CheckpointWriter( private var stopped = false private var fs_ : FileSystem = _ - class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { + class CheckpointWriteHandler( + checkpointTime: Time, + bytes: Array[Byte], + clearCheckpointDataLater: Boolean) extends Runnable { def run() { var attempts = 0 val startTime = System.currentTimeMillis() @@ -163,7 +169,7 @@ class CheckpointWriter( val finishTime = System.currentTimeMillis() logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms") - jobGenerator.onCheckpointCompletion(checkpointTime) + jobGenerator.onCheckpointCompletion(checkpointTime, clearCheckpointDataLater) return } catch { case ioe: IOException => @@ -177,7 +183,7 @@ class CheckpointWriter( } } - def write(checkpoint: Checkpoint) { + def write(checkpoint: Checkpoint, clearCheckpointDataLater: Boolean) { val bos = new ByteArrayOutputStream() val zos = compressionCodec.compressedOutputStream(bos) val oos = new ObjectOutputStream(zos) @@ -185,7 +191,8 @@ class CheckpointWriter( oos.close() bos.close() try { - executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray)) + executor.execute(new CheckpointWriteHandler( + checkpoint.checkpointTime, bos.toByteArray, clearCheckpointDataLater)) logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 8ef0787137845..543224d4b07bc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -27,10 +27,12 @@ import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.spark._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.input.FixedLengthBinaryInputFormat import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream._ @@ -114,7 +116,7 @@ class StreamingContext private[streaming] ( private[streaming] val sc: SparkContext = { if (isCheckpointPresent) { - new SparkContext(cp_.sparkConf) + new SparkContext(cp_.createSparkConf()) } else { sc_ } @@ -359,6 +361,30 @@ class StreamingContext private[streaming] ( new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly) } + /** + * Create a input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * Files must be written to the monitored directory by "moving" them from another + * location within the same file system. File names starting with . are ignored. + * @param directory HDFS directory to monitor for new file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @param conf Hadoop configuration + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[ + K: ClassTag, + V: ClassTag, + F <: NewInputFormat[K, V]: ClassTag + ] (directory: String, + filter: Path => Boolean, + newFilesOnly: Boolean, + conf: Configuration): InputDStream[(K, V)] = { + new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly, Option(conf)) + } + /** * Create a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as text files (using key as LongWritable, value @@ -371,6 +397,37 @@ class StreamingContext private[streaming] ( fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) } + /** + * :: Experimental :: + * + * Create an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them as flat binary files, assuming a fixed length per record, + * generating one byte array per record. Files must be written to the monitored directory + * by "moving" them from another location within the same file system. File names + * starting with . are ignored. + * + * '''Note:''' We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. + * + * @param directory HDFS directory to monitor for new file + * @param recordLength length of each record in bytes + */ + @Experimental + def binaryRecordsStream( + directory: String, + recordLength: Int): DStream[Array[Byte]] = { + val conf = sc_.hadoopConfiguration + conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) + val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat]( + directory, FileInputDStream.defaultFilter : Path => Boolean, newFilesOnly=true, conf) + val data = br.map { case (k, v) => + val bytes = v.getBytes + assert(bytes.length == recordLength, "Byte array does not have correct length") + bytes + } + data + } + /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. @@ -469,10 +526,23 @@ class StreamingContext private[streaming] ( * will be thrown in this thread. * @param timeout time to wait in milliseconds */ + @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long) { waiter.waitForStopOrError(timeout) } + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + * + * @param timeout time to wait in milliseconds + * @return `true` if it's stopped; or throw the reported error during the execution; or `false` + * if the waiting time elapsed before returning from the method. + */ + def awaitTerminationOrTimeout(timeout: Long): Boolean = { + waiter.waitForStopOrError(timeout) + } + /** * Stop the execution of the streams immediately (does not wait for all received data * to be processed). @@ -508,6 +578,7 @@ class StreamingContext private[streaming] ( // Even if we have already stopped, we still need to attempt to stop the SparkContext because // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). if (stopSparkContext) sc.stop() + uiTab.foreach(_.detach()) // The state should always be Stopped after calling `stop()`, even if we haven't started yet: state = Stopped } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala index 505e4431e4350..01cdcb0574040 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -36,7 +36,7 @@ import org.apache.spark.streaming.dstream.DStream * [[org.apache.spark.streaming.api.java.JavaPairDStream]]. */ class JavaDStream[T](val dstream: DStream[T])(implicit val classTag: ClassTag[T]) - extends JavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] { + extends AbstractJavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] { override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index c382a12f4d099..2eabdd9387913 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -34,6 +34,15 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.api.java.JavaDStream._ import org.apache.spark.streaming.dstream.DStream +/** + * As a workaround for https://issues.scala-lang.org/browse/SI-8905, implementations + * of JavaDStreamLike should extend this dummy abstract class instead of directly inheriting + * from the trait. See SPARK-3266 for additional details. + */ +private[streaming] +abstract class AbstractJavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], + R <: JavaRDDLike[T, R]] extends JavaDStreamLike[T, This, R] + trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]] extends Serializable { implicit val classTag: ClassTag[T] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index de124cf40eff1..f94f2d0e8bd31 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -45,7 +45,7 @@ import org.apache.spark.streaming.dstream.DStream class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( implicit val kManifest: ClassTag[K], implicit val vManifest: ClassTag[V]) - extends JavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] { + extends AbstractJavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] { override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) @@ -726,7 +726,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsHadoopFiles[F <: OutputFormat[K, V]](prefix: String, suffix: String) { + def saveAsHadoopFiles(prefix: String, suffix: String) { dstream.saveAsHadoopFiles(prefix, suffix) } @@ -734,12 +734,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsHadoopFiles( + def saveAsHadoopFiles[F <: OutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]]) { + outputFormatClass: Class[F]) { dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) } @@ -747,12 +747,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsHadoopFiles( + def saveAsHadoopFiles[F <: OutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], + outputFormatClass: Class[F], conf: JobConf) { dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } @@ -761,7 +761,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](prefix: String, suffix: String) { + def saveAsNewAPIHadoopFiles(prefix: String, suffix: String) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix) } @@ -769,12 +769,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsNewAPIHadoopFiles( + def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { + outputFormatClass: Class[F]) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) } @@ -782,12 +782,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsNewAPIHadoopFiles( + def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]], + outputFormatClass: Class[F], conf: Configuration = new Configuration) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 9a2254bcdc1f7..e3db01c1e12c6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import org.apache.spark.rdd.RDD @@ -177,7 +178,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create an input stream from network source hostname:port. Data is received using - * a TCP socket and the receive bytes it interepreted as object using the given + * a TCP socket and the receive bytes it interpreted as object using the given * converter. * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data @@ -209,6 +210,24 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.textFileStream(directory) } + /** + * :: Experimental :: + * + * Create an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them as flat binary files with fixed record lengths, + * yielding byte arrays + * + * '''Note:''' We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. + * + * @param directory HDFS directory to monitor for new files + * @param recordLength The length at which to split the records + */ + @Experimental + def binaryRecordsStream(directory: String, recordLength: Int): JavaDStream[Array[Byte]] = { + ssc.binaryRecordsStream(directory, recordLength) + } + /** * Create an input stream from network source hostname:port, where data is received * as serialized blocks (serialized using the Spark's serializer) that can be directly @@ -298,6 +317,37 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.fileStream[K, V, F](directory, fn, newFilesOnly) } + /** + * Create an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * Files must be written to the monitored directory by "moving" them from another + * location within the same file system. File names starting with . are ignored. + * @param directory HDFS directory to monitor for new file + * @param kClass class of key for reading HDFS file + * @param vClass class of value for reading HDFS file + * @param fClass class of input format for reading HDFS file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @param conf Hadoop configuration + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[K, V, F <: NewInputFormat[K, V]]( + directory: String, + kClass: Class[K], + vClass: Class[V], + fClass: Class[F], + filter: JFunction[Path, JBoolean], + newFilesOnly: Boolean, + conf: Configuration): JavaPairInputDStream[K, V] = { + implicit val cmk: ClassTag[K] = ClassTag(kClass) + implicit val cmv: ClassTag[V] = ClassTag(vClass) + implicit val cmf: ClassTag[F] = ClassTag(fClass) + def fn = (x: Path) => filter.call(x).booleanValue() + ssc.fileStream[K, V, F](directory, fn, newFilesOnly, conf) + } + /** * Create an input stream with any arbitrary user implemented actor receiver. * @param props Props object defining creation of the actor @@ -547,10 +597,23 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * will be thrown in this thread. * @param timeout time to wait in milliseconds */ + @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long): Unit = { ssc.awaitTermination(timeout) } + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + * + * @param timeout time to wait in milliseconds + * @return `true` if it's stopped; or throw the reported error during the execution; or `false` + * if the waiting time elapsed before returning from the method. + */ + def awaitTerminationOrTimeout(timeout: Long): Boolean = { + ssc.awaitTerminationOrTimeout(timeout) + } + /** * Stop the execution of the streams. Will stop the associated JavaSparkContext as well. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index e7c5639a63499..22de8c02e63c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -18,14 +18,15 @@ package org.apache.spark.streaming.dstream import java.io.{IOException, ObjectInputStream} -import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable import scala.reflect.ClassTag +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.spark.SerializableWritable import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ import org.apache.spark.util.{TimeStampedHashMap, Utils} @@ -68,13 +69,17 @@ import org.apache.spark.util.{TimeStampedHashMap, Utils} * processing semantics are undefined. */ private[streaming] -class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : ClassTag]( +class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( @transient ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, - newFilesOnly: Boolean = true) + newFilesOnly: Boolean = true, + conf: Option[Configuration] = None) + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) extends InputDStream[(K, V)](ssc_) { + private val serializableConfOpt = conf.map(new SerializableWritable(_)) + // This is a def so that it works during checkpoint recovery: private def clock = ssc.scheduler.clock @@ -83,7 +88,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas // Initial ignore threshold based on which old, existing files in the directory (at the time of // starting the streaming application) will be ignored or considered - private val initialModTimeIgnoreThreshold = if (newFilesOnly) clock.currentTime() else 0L + private val initialModTimeIgnoreThreshold = if (newFilesOnly) clock.getTimeMillis() else 0L /* * Make sure that the information of files selected in the last few batches are remembered. @@ -156,7 +161,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas */ private def findNewFiles(currentTime: Long): Array[String] = { try { - lastNewFileFindingTime = clock.currentTime() + lastNewFileFindingTime = clock.getTimeMillis() // Calculate ignore threshold val modTimeIgnoreThreshold = math.max( @@ -169,7 +174,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas def accept(path: Path): Boolean = isNewFile(path, currentTime, modTimeIgnoreThreshold) } val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) - val timeTaken = clock.currentTime() - lastNewFileFindingTime + val timeTaken = clock.getTimeMillis() - lastNewFileFindingTime logInfo("Finding new files took " + timeTaken + " ms") logDebug("# cached file times = " + fileToModTime.size) if (timeTaken > slideDuration.milliseconds) { @@ -237,7 +242,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas /** Generate one RDD from an array of files */ private def filesToRDD(files: Seq[String]): RDD[(K, V)] = { val fileRDDs = files.map(file =>{ - val rdd = context.sparkContext.newAPIHadoopFile[K, V, F](file) + val rdd = serializableConfOpt.map(_.value) match { + case Some(config) => context.sparkContext.newAPIHadoopFile( + file, + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]], + config) + case None => context.sparkContext.newAPIHadoopFile[K, V, F](file) + } if (rdd.partitions.size == 0) { logError("File " + file + " has no data in it. Spark Streaming can only ingest " + "files that have been \"moved\" to the directory assigned to the file stream. " + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 79263a7183977..ee5e639b26d91 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -23,7 +23,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.streaming.util.{RecurringTimer, SystemClock} +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.SystemClock /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index f7a8ebee8a544..dcdc27d29c270 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -27,8 +27,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage._ -import org.apache.spark.streaming.util.{Clock, SystemClock, WriteAheadLogFileSegment, WriteAheadLogManager} -import org.apache.spark.util.Utils +import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogManager} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 8632c94349bf9..59488dfb0f8c6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -23,13 +23,15 @@ import akka.actor.{ActorRef, Props, Actor} import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} -import org.apache.spark.streaming.util.{Clock, ManualClock, RecurringTimer} +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{Clock, ManualClock} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent private[scheduler] case class ClearMetadata(time: Time) extends JobGeneratorEvent -private[scheduler] case class DoCheckpoint(time: Time) extends JobGeneratorEvent +private[scheduler] case class DoCheckpoint( + time: Time, clearCheckpointDataLater: Boolean) extends JobGeneratorEvent private[scheduler] case class ClearCheckpointData(time: Time) extends JobGeneratorEvent /** @@ -45,8 +47,14 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val clock = { val clockClass = ssc.sc.conf.get( - "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") - Class.forName(clockClass).newInstance().asInstanceOf[Clock] + "spark.streaming.clock", "org.apache.spark.util.SystemClock") + try { + Class.forName(clockClass).newInstance().asInstanceOf[Clock] + } catch { + case e: ClassNotFoundException if clockClass.startsWith("org.apache.spark.streaming") => + val newClockClass = clockClass.replace("org.apache.spark.streaming", "org.apache.spark") + Class.forName(newClockClass).newInstance().asInstanceOf[Clock] + } } private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, @@ -156,8 +164,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** * Callback called when the checkpoint of a batch has been written. */ - def onCheckpointCompletion(time: Time) { - eventActor ! ClearCheckpointData(time) + def onCheckpointCompletion(time: Time, clearCheckpointDataLater: Boolean) { + if (clearCheckpointDataLater) { + eventActor ! ClearCheckpointData(time) + } } /** Processes all events */ @@ -166,7 +176,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { event match { case GenerateJobs(time) => generateJobs(time) case ClearMetadata(time) => clearMetadata(time) - case DoCheckpoint(time) => doCheckpoint(time) + case DoCheckpoint(time, clearCheckpointDataLater) => + doCheckpoint(time, clearCheckpointDataLater) case ClearCheckpointData(time) => clearCheckpointData(time) } } @@ -238,7 +249,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } - eventActor ! DoCheckpoint(time) + eventActor ! DoCheckpoint(time, clearCheckpointDataLater = false) } /** Clear DStream metadata for the given `time`. */ @@ -248,7 +259,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed if (shouldCheckpoint) { - eventActor ! DoCheckpoint(time) + eventActor ! DoCheckpoint(time, clearCheckpointDataLater = true) } else { // If checkpointing is not enabled, then delete metadata information about // received blocks (block data not saved in any case). Otherwise, wait for @@ -271,11 +282,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } /** Perform checkpoint for the give `time`. */ - private def doCheckpoint(time: Time) { + private def doCheckpoint(time: Time, clearCheckpointDataLater: Boolean) { if (shouldCheckpoint && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { logInfo("Checkpointing graph for time " + time) ssc.graph.updateCheckpointData(time) - checkpointWriter.write(new Checkpoint(ssc, time)) + checkpointWriter.write(new Checkpoint(ssc, time), clearCheckpointDataLater) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 0e0f5bd3b9db4..b3ffc71904c76 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -73,7 +73,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { logDebug("Stopping JobScheduler") // First, stop receiving - receiverTracker.stop() + receiverTracker.stop(processAllReceivedData) // Second, stop generating jobs. If it has to process all received data, // then this will wait for all the processing through JobScheduler to be over. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index e19ac939f9ac5..200cf4ef4b0f1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -27,8 +27,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.streaming.Time -import org.apache.spark.streaming.util.{Clock, WriteAheadLogManager} -import org.apache.spark.util.Utils +import org.apache.spark.streaming.util.WriteAheadLogManager +import org.apache.spark.util.{Clock, Utils} /** Trait representing any event in the ReceivedBlockTracker that updates its state. */ private[streaming] sealed trait ReceivedBlockTrackerLogEvent @@ -150,7 +150,7 @@ private[streaming] class ReceivedBlockTracker( * returns only after the files are cleaned up. */ def cleanupOldBatches(cleanupThreshTime: Time, waitForCompletion: Boolean): Unit = synchronized { - assert(cleanupThreshTime.milliseconds < clock.currentTime()) + assert(cleanupThreshTime.milliseconds < clock.getTimeMillis()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq logInfo("Deleting batches " + timesToCleanup) writeToLog(BatchCleanupEvent(timesToCleanup)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 4f998869731ed..b36aeb341d25e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -86,10 +86,10 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } /** Stop the receiver execution thread. */ - def stop() = synchronized { + def stop(graceful: Boolean) = synchronized { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers - if (!skipReceiverLaunch) receiverExecutor.stop() + if (!skipReceiverLaunch) receiverExecutor.stop(graceful) // Finally, stop the actor ssc.env.actorSystem.stop(actor) @@ -218,6 +218,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** This thread class runs all the receivers on the cluster. */ class ReceiverLauncher { @transient val env = ssc.env + @volatile @transient private var running = false @transient val thread = new Thread() { override def run() { try { @@ -233,7 +234,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false thread.start() } - def stop() { + def stop(graceful: Boolean) { // Send the stop signal to all the receivers stopReceivers() @@ -241,9 +242,19 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // That is, for the receivers to quit gracefully. thread.join(10000) + if (graceful) { + val pollTime = 100 + def done = { receiverInfo.isEmpty && !running } + logInfo("Waiting for receiver job to terminate gracefully") + while(!done) { + Thread.sleep(pollTime) + } + logInfo("Waited for receiver job to terminate gracefully") + } + // Check if all the receivers have been deregistered or not if (!receiverInfo.isEmpty) { - logWarning("All of the receivers have not deregistered, " + receiverInfo) + logWarning("Not all of the receivers have deregistered, " + receiverInfo) } else { logInfo("All of the receivers have deregistered successfully") } @@ -295,7 +306,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") + running = true ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) + running = false logInfo("All of the receivers have been terminated") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 98e9a2e639e25..bfe8086fcf8fe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -32,7 +32,7 @@ private[ui] class StreamingPage(parent: StreamingTab) extends WebUIPage("") with Logging { private val listener = parent.listener - private val startTime = Calendar.getInstance().getTime() + private val startTime = System.currentTimeMillis() private val emptyCell = "-" /** Render the page */ @@ -47,7 +47,7 @@ private[ui] class StreamingPage(parent: StreamingTab) /** Generate basic stats of the streaming program */ private def generateBasicStats(): Seq[Node] = { - val timeSinceStart = System.currentTimeMillis() - startTime.getTime + val timeSinceStart = System.currentTimeMillis() - startTime
  • Started at: {startTime.toString} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index d9d04cd706a04..9a860ea4a6c68 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -36,6 +36,10 @@ private[spark] class StreamingTab(ssc: StreamingContext) ssc.addStreamingListener(listener) attachPage(new StreamingPage(this)) parent.attachTab(this) + + def detach() { + getSparkUI(ssc).detachTab(this) + } } private object StreamingTab { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala deleted file mode 100644 index d6d96d7ba00fd..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.util - -private[streaming] -trait Clock { - def currentTime(): Long - def waitTillTime(targetTime: Long): Long -} - -private[streaming] -class SystemClock() extends Clock { - - val minPollTime = 25L - - def currentTime(): Long = { - System.currentTimeMillis() - } - - def waitTillTime(targetTime: Long): Long = { - var currentTime = 0L - currentTime = System.currentTimeMillis() - - var waitTime = targetTime - currentTime - if (waitTime <= 0) { - return currentTime - } - - val pollTime = math.max(waitTime / 10.0, minPollTime).toLong - - while (true) { - currentTime = System.currentTimeMillis() - waitTime = targetTime - currentTime - if (waitTime <= 0) { - return currentTime - } - val sleepTime = math.min(waitTime, pollTime) - Thread.sleep(sleepTime) - } - -1 - } -} - -private[streaming] -class ManualClock() extends Clock { - - private var time = 0L - - def currentTime() = this.synchronized { - time - } - - def setTime(timeToSet: Long) = { - this.synchronized { - time = timeToSet - this.notifyAll() - } - } - - def addToTime(timeToAdd: Long) = { - this.synchronized { - time += timeToAdd - this.notifyAll() - } - } - def waitTillTime(targetTime: Long): Long = { - this.synchronized { - while (time < targetTime) { - this.wait(100) - } - } - currentTime() - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index 1a616a0434f2c..c8eef833eb431 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.util import org.apache.spark.Logging +import org.apache.spark.util.{Clock, SystemClock} private[streaming] class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String) @@ -38,7 +39,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: * current system time. */ def getStartTime(): Long = { - (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period + (math.floor(clock.getTimeMillis().toDouble / period) + 1).toLong * period } /** @@ -48,7 +49,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: * more than current time. */ def getRestartTime(originalStartTime: Long): Long = { - val gap = clock.currentTime - originalStartTime + val gap = clock.getTimeMillis() - originalStartTime (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala index 166661b7496df..985ded9111f74 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala @@ -19,13 +19,12 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{Clock, SystemClock, Utils} import WriteAheadLogManager._ /** @@ -82,7 +81,7 @@ private[streaming] class WriteAheadLogManager( var succeeded = false while (!succeeded && failures < maxFailures) { try { - fileSegment = getLogWriter(clock.currentTime).write(byteBuffer) + fileSegment = getLogWriter(clock.getTimeMillis()).write(byteBuffer) succeeded = true } catch { case ex: Exception => diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 2df8cf6a8a3df..57302ff407183 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1828,4 +1828,22 @@ private List> fileTestPrepare(File testDir) throws IOException { return expected; } + + // SPARK-5795: no logic assertions, just testing that intended API invocations compile + private void compileSaveAsJavaAPI(JavaPairDStream pds) { + pds.saveAsNewAPIHadoopFiles( + "", "", LongWritable.class, Text.class, + org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + pds.saveAsHadoopFiles( + "", "", LongWritable.class, Text.class, + org.apache.hadoop.mapred.SequenceFileOutputFormat.class); + // Checks that a previous common workaround for this API still compiles + pds.saveAsNewAPIHadoopFiles( + "", "", LongWritable.class, Text.class, + (Class) org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + pds.saveAsHadoopFiles( + "", "", LongWritable.class, Text.class, + (Class) org.apache.hadoop.mapred.SequenceFileOutputFormat.class); + } + } diff --git a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60c..cfedb5a042a35 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index e8f4a7779ec21..cf191715d29d6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -22,13 +22,12 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.language.existentials import scala.reflect.ClassTag -import util.ManualClock - import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} +import org.apache.spark.util.{Clock, ManualClock} import org.apache.spark.HashPartitioner class BasicOperationsSuite extends TestSuiteBase { @@ -586,7 +585,7 @@ class BasicOperationsSuite extends TestSuiteBase { for (i <- 0 until input.size) { testServer.send(input(i).toString + "\n") Thread.sleep(200) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) collectRddInfo() } @@ -637,8 +636,8 @@ class BasicOperationsSuite extends TestSuiteBase { ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]] if (rememberDuration != null) ssc.remember(rememberDuration) val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput) - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - assert(clock.currentTime() === Seconds(10).milliseconds) + val clock = ssc.scheduler.clock.asInstanceOf[Clock] + assert(clock.getTimeMillis() === Seconds(10).milliseconds) assert(output.size === numExpectedOutput) operatedStream } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 8f8bc61437ba5..8ea91eca683cf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -32,8 +32,7 @@ import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutput import org.scalatest.concurrent.Eventually._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} -import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.util.Utils +import org.apache.spark.util.{Clock, ManualClock, Utils} /** * This test suites tests the checkpointing functionality of DStreams - @@ -61,7 +60,7 @@ class CheckpointSuite extends TestSuiteBase { assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") val stateStreamCheckpointInterval = Seconds(1) val fs = FileSystem.getLocal(new Configuration()) @@ -147,7 +146,7 @@ class CheckpointSuite extends TestSuiteBase { // This tests whether spark conf persists through checkpoints, and certain // configs gets scrubbed - test("persistence of conf through checkpoints") { + test("recovery of conf through checkpoints") { val key = "spark.mykey" val value = "myvalue" System.setProperty(key, value) @@ -155,7 +154,7 @@ class CheckpointSuite extends TestSuiteBase { val originalConf = ssc.conf val cp = new Checkpoint(ssc, Time(1000)) - val cpConf = cp.sparkConf + val cpConf = cp.createSparkConf() assert(cpConf.get("spark.master") === originalConf.get("spark.master")) assert(cpConf.get("spark.app.name") === originalConf.get("spark.app.name")) assert(cpConf.get(key) === value) @@ -164,7 +163,8 @@ class CheckpointSuite extends TestSuiteBase { // Serialize/deserialize to simulate write to storage and reading it back val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) - val newCpConf = newCp.sparkConf + // Verify new SparkConf has all the previous properties + val newCpConf = newCp.createSparkConf() assert(newCpConf.get("spark.master") === originalConf.get("spark.master")) assert(newCpConf.get("spark.app.name") === originalConf.get("spark.app.name")) assert(newCpConf.get(key) === value) @@ -175,6 +175,20 @@ class CheckpointSuite extends TestSuiteBase { ssc = new StreamingContext(null, newCp, null) val restoredConf = ssc.conf assert(restoredConf.get(key) === value) + ssc.stop() + + // Verify new SparkConf picks up new master url if it is set in the properties. See SPARK-6331. + try { + val newMaster = "local[100]" + System.setProperty("spark.master", newMaster) + val newCpConf = newCp.createSparkConf() + assert(newCpConf.get("spark.master") === newMaster) + assert(newCpConf.get("spark.app.name") === originalConf.get("spark.app.name")) + ssc = new StreamingContext(null, newCp, null) + assert(ssc.sparkContext.master === newMaster) + } finally { + System.clearProperty("spark.master") + } } @@ -324,13 +338,13 @@ class CheckpointSuite extends TestSuiteBase { * Writes a file named `i` (which contains the number `i`) to the test directory and sets its * modification time to `clock`'s current time. */ - def writeFile(i: Int, clock: ManualClock): Unit = { + def writeFile(i: Int, clock: Clock): Unit = { val file = new File(testDir, i.toString) Files.write(i + "\n", file, Charsets.UTF_8) - assert(file.setLastModified(clock.currentTime())) + assert(file.setLastModified(clock.getTimeMillis())) // Check that the file's modification date is actually the value we wrote, since rounding or // truncation will break the test: - assert(file.lastModified() === clock.currentTime()) + assert(file.lastModified() === clock.getTimeMillis()) } /** @@ -372,13 +386,13 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Advance half a batch so that the first file is created after the StreamingContext starts - clock.addToTime(batchDuration.milliseconds / 2) + clock.advance(batchDuration.milliseconds / 2) // Create files and advance manual clock to process them for (i <- Seq(1, 2, 3)) { writeFile(i, clock) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) if (i != 3) { // Since we want to shut down while the 3rd batch is processing eventually(eventuallyTimeout) { @@ -386,7 +400,7 @@ class CheckpointSuite extends TestSuiteBase { } } } - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { // Wait until all files have been recorded and all batches have started assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) @@ -410,7 +424,7 @@ class CheckpointSuite extends TestSuiteBase { writeFile(i, clock) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } // Recover context from checkpoint file and verify whether the files that were @@ -419,7 +433,7 @@ class CheckpointSuite extends TestSuiteBase { withStreamingContext(new StreamingContext(checkpointDir)) { ssc => // So that the restarted StreamingContext's clock has gone forward in time since failure ssc.conf.set("spark.streaming.manualClock.jump", (batchDuration * 3).milliseconds.toString) - val oldClockTime = clock.currentTime() + val oldClockTime = clock.getTimeMillis() clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val batchCounter = new BatchCounter(ssc) val outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] @@ -430,7 +444,7 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Verify that the clock has traveled forward to the expected time eventually(eventuallyTimeout) { - clock.currentTime() === oldClockTime + clock.getTimeMillis() === oldClockTime } // Wait for pre-failure batch to be recomputed (3 while SSC was down plus last batch) val numBatchesAfterRestart = 4 @@ -441,12 +455,12 @@ class CheckpointSuite extends TestSuiteBase { writeFile(i, clock) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1) } } - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() @@ -521,12 +535,12 @@ class CheckpointSuite extends TestSuiteBase { */ def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.currentTime()) + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) for (i <- 1 to numBatches.toInt) { - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(batchDuration.milliseconds) } - logInfo("Manual clock after advancing = " + clock.currentTime()) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) Thread.sleep(batchDuration.milliseconds) val outputStream = ssc.graph.getOutputStreams.filter { dstream => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index bddf51e130422..7ed6320a3d0bc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -17,12 +17,8 @@ package org.apache.spark.streaming -import akka.actor.Actor -import akka.actor.Props -import akka.util.ByteString - import java.io.{File, BufferedWriter, OutputStreamWriter} -import java.net.{InetSocketAddress, SocketException, ServerSocket} +import java.net.{SocketException, ServerSocket} import java.nio.charset.Charset import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue} import java.util.concurrent.atomic.AtomicInteger @@ -36,9 +32,8 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.util.Utils -import org.apache.spark.streaming.receiver.{ActorHelper, Receiver} +import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.rdd.RDD import org.apache.hadoop.io.{Text, LongWritable} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat @@ -69,7 +64,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { for (i <- 0 until input.size) { testServer.send(input(i).toString + "\n") Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } Thread.sleep(1000) logInfo("Stopping server") @@ -95,6 +90,57 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } + test("binary records stream") { + val testDir: File = null + try { + val batchDuration = Seconds(2) + val testDir = Utils.createTempDir() + // Create a file that exists before the StreamingContext is created: + val existingFile = new File(testDir, "0") + Files.write("0\n", existingFile, Charset.forName("UTF-8")) + assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) + + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + // This `setTime` call ensures that the clock is past the creation time of `existingFile` + clock.setTime(existingFile.lastModified + batchDuration.milliseconds) + val batchCounter = new BatchCounter(ssc) + val fileStream = ssc.binaryRecordsStream(testDir.toString, 1) + val outputBuffer = new ArrayBuffer[Seq[Array[Byte]]] + with SynchronizedBuffer[Seq[Array[Byte]]] + val outputStream = new TestOutputStream(fileStream, outputBuffer) + outputStream.register() + ssc.start() + + // Advance the clock so that the files are created after StreamingContext starts, but + // not enough to trigger a batch + clock.advance(batchDuration.milliseconds / 2) + + val input = Seq(1, 2, 3, 4, 5) + input.foreach { i => + Thread.sleep(batchDuration.milliseconds) + val file = new File(testDir, i.toString) + Files.write(Array[Byte](i.toByte), file) + assert(file.setLastModified(clock.getTimeMillis())) + assert(file.lastModified === clock.getTimeMillis()) + logInfo("Created file " + file) + // Advance the clock after creating the file to avoid a race when + // setting its modification time + clock.advance(batchDuration.milliseconds) + eventually(eventuallyTimeout) { + assert(batchCounter.getNumCompletedBatches === i) + } + } + + val expectedOutput = input.map(i => i.toByte) + val obtainedOutput = outputBuffer.flatten.toList.map(i => i(0).toByte) + assert(obtainedOutput === expectedOutput) + } + } finally { + if (testDir != null) Utils.deleteRecursively(testDir) + } + } test("file input stream - newFilesOnly = true") { testFileStream(newFilesOnly = true) @@ -128,7 +174,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && System.currentTimeMillis() - startTime < 5000) { Thread.sleep(100) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } Thread.sleep(1000) logInfo("Stopping context") @@ -163,7 +209,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { for (i <- 0 until input.size) { // Enqueue more than 1 item per tick but they should dequeue one at a time inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } Thread.sleep(1000) logInfo("Stopping context") @@ -205,12 +251,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Enqueue the first 3 items (one by one), they should be merged in the next batch val inputIterator = input.toIterator inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(1000) // Enqueue the remaining items (again one by one), merged in the final batch inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(1000) logInfo("Stopping context") ssc.stop() @@ -257,19 +303,19 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Advance the clock so that the files are created after StreamingContext starts, but // not enough to trigger a batch - clock.addToTime(batchDuration.milliseconds / 2) + clock.advance(batchDuration.milliseconds / 2) // Over time, create files in the directory val input = Seq(1, 2, 3, 4, 5) input.foreach { i => val file = new File(testDir, i.toString) Files.write(i + "\n", file, Charset.forName("UTF-8")) - assert(file.setLastModified(clock.currentTime())) - assert(file.lastModified === clock.currentTime) + assert(file.setLastModified(clock.getTimeMillis())) + assert(file.lastModified === clock.getTimeMillis()) logInfo("Created file " + file) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === i) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 132ff2443fc0f..818f551dbe996 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.{AkkaUtils, ManualClock} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ @@ -165,7 +165,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche preCleanupLogFiles.size should be > 1 // this depends on the number of blocks inserted using generateAndStoreData() - manualClock.currentTime() shouldEqual 5000L + manualClock.getTimeMillis() shouldEqual 5000L val cleanupThreshTime = 3000L handler.cleanupOldBlocks(cleanupThreshTime) @@ -243,7 +243,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val blockIds = Seq.fill(blocks.size)(generateBlockId()) val storeResults = blocks.zip(blockIds).map { case (block, id) => - manualClock.addToTime(500) // log rolling interval set to 1000 ms through SparkConf + manualClock.advance(500) // log rolling interval set to 1000 ms through SparkConf logDebug("Inserting block " + id) receivedBlockHandler.storeBlock(id, block) }.toList diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index fbb7b0bfebafc..a3a0fd5187403 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -34,9 +34,9 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock, WriteAheadLogReader} +import org.apache.spark.streaming.util.WriteAheadLogReader import org.apache.spark.streaming.util.WriteAheadLogSuite._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} class ReceivedBlockTrackerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { @@ -100,7 +100,7 @@ class ReceivedBlockTrackerSuite def incrementTime() { val timeIncrementMillis = 2000L - manualClock.addToTime(timeIncrementMillis) + manualClock.advance(timeIncrementMillis) } // Generate and add blocks to the given tracker @@ -138,13 +138,13 @@ class ReceivedBlockTrackerSuite tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 // Allocate blocks to batch and verify whether the unallocated blocks got allocated - val batchTime1 = manualClock.currentTime + val batchTime1 = manualClock.getTimeMillis() tracker2.allocateBlocksToBatch(batchTime1) tracker2.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 // Add more blocks and allocate to another batch incrementTime() - val batchTime2 = manualClock.currentTime + val batchTime2 = manualClock.getTimeMillis() val blockInfos2 = addBlockInfos(tracker2) tracker2.allocateBlocksToBatch(batchTime2) tracker2.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 9f352bdcb0893..2e5005ef6ff14 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -100,7 +100,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w assert(cp.sparkConfPairs.toMap.getOrElse("spark.cleaner.ttl", "-1") === "10") ssc1.stop() val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) - assert(newCp.sparkConf.getInt("spark.cleaner.ttl", -1) === 10) + assert(newCp.createSparkConf().getInt("spark.cleaner.ttl", -1) === 10) ssc = new StreamingContext(null, newCp, null) assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) } @@ -190,7 +190,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w logInfo("Count = " + count + ", Running count = " + runningCount) } ssc.start() - ssc.awaitTermination(500) + ssc.awaitTerminationOrTimeout(500) ssc.stop(stopSparkContext = false, stopGracefully = true) logInfo("Running count = " + runningCount) logInfo("TestReceiver.counter = " + TestReceiver.counter.get()) @@ -205,6 +205,32 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("stop slow receiver gracefully") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("spark.streaming.gracefulStopTimeout", "20000") + sc = new SparkContext(conf) + logInfo("==================================\n\n\n") + ssc = new StreamingContext(sc, Milliseconds(100)) + var runningCount = 0 + SlowTestReceiver.receivedAllRecords = false + //Create test receiver that sleeps in onStop() + val totalNumRecords = 15 + val recordsPerSecond = 1 + val input = ssc.receiverStream(new SlowTestReceiver(totalNumRecords, recordsPerSecond)) + input.count().foreachRDD { rdd => + val count = rdd.first() + runningCount += count.toInt + logInfo("Count = " + count + ", Running count = " + runningCount) + } + ssc.start() + ssc.awaitTerminationOrTimeout(500) + ssc.stop(stopSparkContext = false, stopGracefully = true) + logInfo("Running count = " + runningCount) + assert(runningCount > 0) + assert(runningCount == totalNumRecords) + Thread.sleep(100) + } + test("awaitTermination") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) @@ -217,7 +243,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w // test whether awaitTermination() exits after give amount of time failAfter(1000 millis) { - ssc.awaitTermination(500) + ssc.awaitTerminationOrTimeout(500) } // test whether awaitTermination() does not exit if not time is given @@ -262,7 +288,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val exception = intercept[Exception] { ssc.start() - ssc.awaitTermination(5000) + ssc.awaitTerminationOrTimeout(5000) } assert(exception.getMessage.contains("map task"), "Expected exception not thrown") } @@ -273,11 +299,35 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w inputStream.transform { rdd => throw new TestException("error in transform"); rdd }.register() val exception = intercept[TestException] { ssc.start() - ssc.awaitTermination(5000) + ssc.awaitTerminationOrTimeout(5000) } assert(exception.getMessage.contains("transform"), "Expected exception not thrown") } + test("awaitTerminationOrTimeout") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + inputStream.map(x => x).register() + + ssc.start() + + // test whether awaitTerminationOrTimeout() return false after give amount of time + failAfter(1000 millis) { + assert(ssc.awaitTerminationOrTimeout(500) === false) + } + + // test whether awaitTerminationOrTimeout() return true if context is stopped + failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown + new Thread() { + override def run() { + Thread.sleep(500) + ssc.stop() + } + }.start() + assert(ssc.awaitTerminationOrTimeout(10000) === true) + } + } + test("DStream and generated RDD creation sites") { testPackage.test() } @@ -319,6 +369,38 @@ object TestReceiver { val counter = new AtomicInteger(1) } +/** Custom receiver for testing whether a slow receiver can be shutdown gracefully or not */ +class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { + + var receivingThreadOption: Option[Thread] = None + + def onStart() { + val thread = new Thread() { + override def run() { + logInfo("Receiving started") + for(i <- 1 to totalRecords) { + Thread.sleep(1000 / recordsPerSecond) + store(i) + } + SlowTestReceiver.receivedAllRecords = true + logInfo(s"Received all $totalRecords records") + } + } + receivingThreadOption = Some(thread) + thread.start() + } + + def onStop() { + // Simulate slow receiver by waiting for all records to be produced + while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100) + // no cleanup to be done, the receiving thread should stop on it own + } +} + +object SlowTestReceiver { + var receivedAllRecords = false +} + /** Streaming application for testing DStream and RDD creation sites */ package object testPackage extends Assertions { def test() { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 7d82c3e4aadcf..3565d621e8a6c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -31,10 +31,9 @@ import org.scalatest.concurrent.PatienceConfiguration import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.scheduler.{StreamingListenerBatchStarted, StreamingListenerBatchCompleted, StreamingListener} -import org.apache.spark.streaming.util.ManualClock import org.apache.spark.{SparkConf, Logging} import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{ManualClock, Utils} /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, @@ -189,10 +188,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { def beforeFunction() { if (useManualClock) { logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") } else { logInfo("Using real clock") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.SystemClock") } } @@ -333,23 +332,23 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Advance manual clock val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.currentTime()) + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) if (actuallyWait) { for (i <- 1 to numBatches) { logInfo("Actually waiting for " + batchDuration) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(batchDuration.milliseconds) } } else { - clock.addToTime(numBatches * batchDuration.milliseconds) + clock.advance(numBatches * batchDuration.milliseconds) } - logInfo("Manual clock after advancing = " + clock.currentTime()) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) // Wait until expected number of output items have been generated val startTime = System.currentTimeMillis() while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) - ssc.awaitTermination(50) + ssc.awaitTerminationOrTimeout(50) } val timeTaken = System.currentTimeMillis() - startTime logInfo("Output generated in " + timeTaken + " milliseconds") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala new file mode 100644 index 0000000000000..87a0395efbf2a --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.openqa.selenium.WebDriver +import org.openqa.selenium.htmlunit.HtmlUnitDriver +import org.scalatest._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.selenium.WebBrowser +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ + + + + +/** + * Selenium tests for the Spark Web UI. + */ +class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { + + implicit var webDriver: WebDriver = _ + + override def beforeAll(): Unit = { + webDriver = new HtmlUnitDriver + } + + override def afterAll(): Unit = { + if (webDriver != null) { + webDriver.quit() + } + } + + /** + * Create a test SparkStreamingContext with the SparkUI enabled. + */ + private def newSparkStreamingContext(): StreamingContext = { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.ui.enabled", "true") + val ssc = new StreamingContext(conf, Seconds(1)) + assert(ssc.sc.ui.isDefined, "Spark UI is not started!") + ssc + } + + test("attaching and detaching a Streaming tab") { + withStreamingContext(newSparkStreamingContext()) { ssc => + val sparkUI = ssc.sparkContext.ui.get + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sparkUI.appUIAddress.stripSuffix("/")) + find(cssSelector( """ul li a[href*="streaming"]""")) should not be (None) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + // check whether streaming page exists + go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming") + val statisticText = findAll(cssSelector("li strong")).map(_.text).toSeq + statisticText should contain("Network receivers:") + statisticText should contain("Batch interval:") + } + + ssc.stop(false) + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sparkUI.appUIAddress.stripSuffix("/")) + find(cssSelector( """ul li a[href*="streaming"]""")) should be(None) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming") + val statisticText = findAll(cssSelector("li strong")).map(_.text).toSeq + statisticText should not contain ("Network receivers:") + statisticText should not contain ("Batch interval:") + } + } + } +} + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala deleted file mode 100644 index 8e30118266855..0000000000000 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -import scala.io.Source - -import org.scalatest.FunSuite -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ - -import org.apache.spark.SparkConf - -class UISuite extends FunSuite { - - // Ignored: See SPARK-1530 - ignore("streaming tab in spark UI") { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - .set("spark.ui.enabled", "true") - val ssc = new StreamingContext(conf, Seconds(1)) - assert(ssc.sc.ui.isDefined, "Spark UI is not started!") - val ui = ssc.sc.ui.get - - eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(ui.appUIAddress).mkString - assert(!html.contains("random data that should not be present")) - // test if streaming tab exist - assert(html.toLowerCase.contains("streaming")) - // test if other Spark tabs still exist - assert(html.toLowerCase.contains("stages")) - } - - eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(ui.appUIAddress.stripSuffix("/") + "/streaming").mkString - assert(html.toLowerCase.contains("batch")) - assert(html.toLowerCase.contains("network")) - } - } -} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala new file mode 100644 index 0000000000000..4150b60635ed6 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.util.concurrent.CountDownLatch + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming._ +import org.apache.spark.util.{ManualClock, Utils} + +class JobGeneratorSuite extends TestSuiteBase { + + // SPARK-6222 is a tricky regression bug which causes received block metadata + // to be deleted before the corresponding batch has completed. This occurs when + // the following conditions are met. + // 1. streaming checkpointing is enabled by setting streamingContext.checkpoint(dir) + // 2. input data is received through a receiver as blocks + // 3. a batch processing a set of blocks takes a long time, such that a few subsequent + // batches have been generated and submitted for processing. + // + // The JobGenerator (as of Mar 16, 2015) checkpoints twice per batch, once after generation + // of a batch, and another time after the completion of a batch. The cleanup of + // checkpoint data (including block metadata, etc.) from DStream must be done only after the + // 2nd checkpoint has completed, that is, after the batch has been completely processed. + // However, the issue is that the checkpoint data and along with it received block data is + // cleaned even in the case of the 1st checkpoint, causing pre-mature deletion of received block + // data. For example, if the 3rd batch is still being process, the 7th batch may get generated, + // and the corresponding "1st checkpoint" will delete received block metadata of batch older + // than 6th batch. That, is 3rd batch's block metadata gets deleted even before 3rd batch has + // been completely processed. + // + // This test tries to create that scenario by the following. + // 1. enable checkpointing + // 2. generate batches with received blocks + // 3. make the 3rd batch never complete + // 4. allow subsequent batches to be generated (to allow premature deletion of 3rd batch metadata) + // 5. verify whether 3rd batch's block metadata still exists + // + test("SPARK-6222: Do not clear received block data too soon") { + import JobGeneratorSuite._ + val checkpointDir = Utils.createTempDir() + val testConf = conf + testConf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + testConf.set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + + withStreamingContext(new StreamingContext(testConf, batchDuration)) { ssc => + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val numBatches = 10 + val longBatchNumber = 3 // 3rd batch will take a long time + val longBatchTime = longBatchNumber * batchDuration.milliseconds + + val testTimeout = timeout(10 seconds) + val inputStream = ssc.receiverStream(new TestReceiver) + + inputStream.foreachRDD((rdd: RDD[Int], time: Time) => { + if (time.milliseconds == longBatchTime) { + while (waitLatch.getCount() > 0) { + waitLatch.await() + println("Await over") + } + } + }) + + val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir.getAbsolutePath) + ssc.start() + + // Make sure the only 1 batch of information is to be remembered + assert(inputStream.rememberDuration === batchDuration) + val receiverTracker = ssc.scheduler.receiverTracker + + // Get the blocks belonging to a batch + def getBlocksOfBatch(batchTime: Long) = { + receiverTracker.getBlocksOfBatchAndStream(Time(batchTime), inputStream.id) + } + + // Wait for new blocks to be received + def waitForNewReceivedBlocks() { + eventually(testTimeout) { + assert(receiverTracker.hasUnallocatedBlocks) + } + } + + // Wait for received blocks to be allocated to a batch + def waitForBlocksToBeAllocatedToBatch(batchTime: Long) { + eventually(testTimeout) { + assert(getBlocksOfBatch(batchTime).nonEmpty) + } + } + + // Generate a large number of batches with blocks in them + for (batchNum <- 1 to numBatches) { + waitForNewReceivedBlocks() + clock.advance(batchDuration.milliseconds) + waitForBlocksToBeAllocatedToBatch(clock.getTimeMillis()) + } + + // Wait for 3rd batch to start + eventually(testTimeout) { + ssc.scheduler.getPendingTimes().contains(Time(numBatches * batchDuration.milliseconds)) + } + + // Verify that the 3rd batch's block data is still present while the 3rd batch is incomplete + assert(getBlocksOfBatch(longBatchTime).nonEmpty, "blocks of incomplete batch already deleted") + assert(batchCounter.getNumCompletedBatches < longBatchNumber) + waitLatch.countDown() + } + } +} + +object JobGeneratorSuite { + val waitLatch = new CountDownLatch(1) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 7ce9499dc614d..8335659667f22 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -26,7 +26,7 @@ import scala.language.{implicitConversions, postfixOps} import WriteAheadLogSuite._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.util.Utils +import org.apache.spark.util.{ManualClock, Utils} import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually._ @@ -197,7 +197,7 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size > 1) - manager.cleanupOldLogs(manualClock.currentTime() / 2, waitForCompletion) + manager.cleanupOldLogs(manualClock.getTimeMillis() / 2, waitForCompletion) if (waitForCompletion) { assert(getLogFilesInDirectory(testDir).size < logFiles.size) @@ -219,7 +219,7 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { // Recover old files and generate a second set of log files val dataToWrite2 = generateRandomData() - manualClock.addToTime(100000) + manualClock.advance(100000) writeDataUsingManager(testDir, dataToWrite2, manualClock) val logFiles2 = getLogFilesInDirectory(testDir) assert(logFiles2.size > logFiles1.size) @@ -279,12 +279,12 @@ object WriteAheadLogSuite { manualClock: ManualClock = new ManualClock, stopManager: Boolean = true ): WriteAheadLogManager = { - if (manualClock.currentTime < 100000) manualClock.setTime(10000) + if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) val manager = new WriteAheadLogManager(logDirectory, hadoopConf, rollingIntervalSecs = 1, callerName = "WriteAheadLogSuite", clock = manualClock) // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => - manualClock.addToTime(500) + manualClock.advance(500) manager.writeToLog(item) } if (stopManager) manager.stop() diff --git a/tools/pom.xml b/tools/pom.xml index e7419ed2c607a..b47bb40b2c016 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 65344aa8738e0..08250ecdb33e5 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.3.2-SNAPSHOT ../pom.xml diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 37e98e01fddf7..418f9048d6565 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -19,9 +19,9 @@ package org.apache.spark.deploy.yarn import scala.util.control.NonFatal -import java.io.IOException +import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException -import java.net.Socket +import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference import akka.actor._ @@ -38,7 +38,8 @@ import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.YarnSchedulerBackend import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, + SignalLogger, Utils} /** * Common application master functionality for Spark on Yarn. @@ -54,7 +55,7 @@ private[spark] class ApplicationMaster( private val sparkConf = new SparkConf() private val yarnConf: YarnConfiguration = SparkHadoopUtil.get.newConfiguration(sparkConf) .asInstanceOf[YarnConfiguration] - private val isDriver = args.userClass != null + private val isClusterMode = args.userClass != null // Default to numExecutors * 2, with minimum of 3 private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", @@ -81,7 +82,7 @@ private[spark] class ApplicationMaster( try { val appAttemptId = client.getAttemptId() - if (isDriver) { + if (isClusterMode) { // Set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box System.setProperty("spark.ui.port", "0") @@ -139,7 +140,7 @@ private[spark] class ApplicationMaster( // doAs in order for the credentials to be passed on to the executor containers. val securityMgr = new SecurityManager(sparkConf) - if (isDriver) { + if (isClusterMode) { runDriver(securityMgr) } else { runExecutorLauncher(securityMgr) @@ -162,7 +163,7 @@ private[spark] class ApplicationMaster( * from the application code. */ final def getDefaultFinalStatus() = { - if (isDriver) { + if (isClusterMode) { FinalApplicationStatus.SUCCEEDED } else { FinalApplicationStatus.UNDEFINED @@ -243,15 +244,14 @@ private[spark] class ApplicationMaster( private def runAMActor( host: String, port: String, - isDriver: Boolean): Unit = { - + isClusterMode: Boolean): Unit = { val driverUrl = AkkaUtils.address( AkkaUtils.protocol(actorSystem), SparkEnv.driverActorSystemName, host, port, YarnSchedulerBackend.ACTOR_NAME) - actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isDriver)), name = "YarnAM") + actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isClusterMode)), name = "YarnAM") } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -272,7 +272,7 @@ private[spark] class ApplicationMaster( runAMActor( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), - isDriver = true) + isClusterMode = true) registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } @@ -427,7 +427,7 @@ private[spark] class ApplicationMaster( sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) - runAMActor(driverHost, driverPort.toString, isDriver = false) + runAMActor(driverHost, driverPort.toString, isClusterMode = false) } /** Add the Yarn IP filter that is required for properly securing the UI. */ @@ -435,7 +435,7 @@ private[spark] class ApplicationMaster( val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" val params = client.getAmIpFilterParams(yarnConf, proxyBase) - if (isDriver) { + if (isClusterMode) { System.setProperty("spark.ui.filters", amFilter) params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) } } else { @@ -453,12 +453,24 @@ private[spark] class ApplicationMaster( private def startUserApplication(): Thread = { logInfo("Starting the user application in a separate Thread") System.setProperty("spark.executor.instances", args.numExecutors.toString) + + val classpath = Client.getUserClasspath(sparkConf) + val urls = classpath.map { entry => + new URL("file:" + new File(entry.getPath()).getAbsolutePath()) + } + val userClassLoader = + if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { + new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } else { + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } + if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { System.setProperty("spark.submit.pyFiles", PythonRunner.formatPaths(args.pyFiles).mkString(",")) } - val mainMethod = Class.forName(args.userClass, false, - Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) + val mainMethod = userClassLoader.loadClass(args.userClass) + .getMethod("main", classOf[Array[String]]) val userThread = new Thread { override def run() { @@ -473,16 +485,16 @@ private[spark] class ApplicationMaster( e.getCause match { case _: InterruptedException => // Reporter thread can interrupt to stop user class - case e: Exception => + case cause: Throwable => + logError("User class threw exception: " + cause.getMessage, cause) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS, - "User class threw exception: " + e.getMessage) - // re-throw to get it logged - throw e + "User class threw exception: " + cause.getMessage) } } } } + userThread.setContextClassLoader(userClassLoader) userThread.setName("Driver") userThread.start() userThread @@ -491,7 +503,7 @@ private[spark] class ApplicationMaster( /** * An actor that communicates with the driver's scheduler backend. */ - private class AMActor(driverUrl: String, isDriver: Boolean) extends Actor { + private class AMActor(driverUrl: String, isClusterMode: Boolean) extends Actor { var driver: ActorSelection = _ override def preStart() = { @@ -503,7 +515,7 @@ private[spark] class ApplicationMaster( driver ! RegisterClusterManager // In cluster mode, the AM can directly monitor the driver status instead // of trying to deduce it from the lifecycle of the driver's actor - if (!isDriver) { + if (!isClusterMode) { context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } } @@ -513,7 +525,7 @@ private[spark] class ApplicationMaster( logInfo(s"Driver terminated or disconnected! Shutting down. $x") // In cluster mode, do not rely on the disassociated event to exit // This avoids potentially reporting incorrect exit codes if the driver fails - if (!isDriver) { + if (!isClusterMode) { finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) } @@ -522,7 +534,6 @@ private[spark] class ApplicationMaster( driver ! x case RequestExecutors(requestedTotal) => - logInfo(s"Driver requested a total number of $requestedTotal executor(s).") Option(allocator) match { case Some(a) => a.requestTotalExecutors(requestedTotal) case None => logWarning("Container allocator is not ready to request executors yet.") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 91e8574e94e2f..61f8fc3f5a014 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -183,8 +183,7 @@ private[spark] class Client( private[yarn] def copyFileToRemote( destDir: Path, srcPath: Path, - replication: Short, - setPerms: Boolean = false): Path = { + replication: Short): Path = { val destFs = destDir.getFileSystem(hadoopConf) val srcFs = srcPath.getFileSystem(hadoopConf) var destPath = srcPath @@ -193,9 +192,7 @@ private[spark] class Client( logInfo(s"Uploading resource $srcPath -> $destPath") FileUtil.copy(srcFs, srcPath, destFs, destPath, false, hadoopConf) destFs.setReplication(destPath, replication) - if (setPerms) { - destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION)) - } + destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION)) } else { logInfo(s"Source and destination file systems are the same. Not copying $srcPath") } @@ -239,23 +236,22 @@ private[spark] class Client( /** * Copy the given main resource to the distributed cache if the scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. - * Each resource is represented by a 4-tuple of: + * Each resource is represented by a 3-tuple of: * (1) destination resource name, * (2) local path to the resource, - * (3) Spark property key to set if the scheme is not local, and - * (4) whether to set permissions for this resource + * (3) Spark property key to set if the scheme is not local */ List( - (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR, false), - (APP_JAR, args.userJar, CONF_SPARK_USER_JAR, true), - ("log4j.properties", oldLog4jConf.orNull, null, false) - ).foreach { case (destName, _localPath, confKey, setPermissions) => + (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR), + (APP_JAR, args.userJar, CONF_SPARK_USER_JAR), + ("log4j.properties", oldLog4jConf.orNull, null) + ).foreach { case (destName, _localPath, confKey) => val localPath: String = if (_localPath != null) _localPath.trim() else "" if (!localPath.isEmpty()) { val localURI = new URI(localPath) if (localURI.getScheme != LOCAL_SCHEME) { val src = getQualifiedLocalPath(localURI, hadoopConf) - val destPath = copyFileToRemote(dst, src, replication, setPermissions) + val destPath = copyFileToRemote(dst, src, replication) val destFs = FileSystem.get(destPath.toUri(), hadoopConf) distCacheMgr.addResource(destFs, hadoopConf, destPath, localResources, LocalResourceType.FILE, destName, statCache) @@ -418,6 +414,8 @@ private[spark] class Client( // In our expts, using (default) throughput collector has severe perf ramifications in // multi-tenant machines javaOpts += "-XX:+UseConcMarkSweepGC" + javaOpts += "-XX:MaxTenuringThreshold=31" + javaOpts += "-XX:SurvivorRatio=8" javaOpts += "-XX:+CMSIncrementalMode" javaOpts += "-XX:+CMSIncrementalPacing" javaOpts += "-XX:CMSIncrementalDutyCycleMin=0" @@ -433,10 +431,11 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { - sparkConf.getOption("spark.driver.extraJavaOptions") + val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions") .orElse(sys.env.get("SPARK_JAVA_OPTS")) - .map(Utils.splitCommandString).getOrElse(Seq.empty) - .foreach(opts => javaOpts += opts) + driverOpts.foreach { opts => + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + } val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { @@ -458,7 +457,7 @@ private[spark] class Client( val msg = s"$amOptsKey is not allowed to alter memory settings (was '$opts')." throw new SparkException(msg) } - javaOpts ++= Utils.splitCommandString(opts) + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } } @@ -704,7 +703,7 @@ object Client extends Logging { * Return the path to the given application's staging directory. */ private def getAppStagingDir(appId: ApplicationId): String = { - SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR + buildPath(SPARK_STAGING, appId.toString()) } /** @@ -780,7 +779,13 @@ object Client extends Logging { /** * Populate the classpath entry in the given environment map. - * This includes the user jar, Spark jar, and any extra application jars. + * + * User jars are generally not added to the JVM's system classpath; those are handled by the AM + * and executor backend. When the deprecated `spark.yarn.user.classpath.first` is used, user jars + * are included in the system classpath, though. The extra class path and other uploaded files are + * always made available through the system class path. + * + * @param args Client arguments (when starting the AM) or null (when starting executors). */ private[yarn] def populateClasspath( args: ClientArguments, @@ -792,48 +797,38 @@ object Client extends Logging { addClasspathEntry( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env ) - - // Normally the users app.jar is last in case conflicts with spark jars if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { - addUserClasspath(args, sparkConf, env) - addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) - populateHadoopClasspath(conf, env) - } else { - addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) - populateHadoopClasspath(conf, env) - addUserClasspath(args, sparkConf, env) + val userClassPath = + if (args != null) { + getUserClasspath(Option(args.userJar), Option(args.addJars)) + } else { + getUserClasspath(sparkConf) + } + userClassPath.foreach { x => + addFileToClasspath(x, null, env) + } } - - // Append all jar files under the working directory to the classpath. - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + "*", env - ) + addFileToClasspath(new URI(sparkJar(sparkConf)), SPARK_JAR, env) + populateHadoopClasspath(conf, env) + sys.env.get(ENV_DIST_CLASSPATH).foreach(addClasspathEntry(_, env)) } /** - * Adds the user jars which have local: URIs (or alternate names, such as APP_JAR) explicitly - * to the classpath. + * Returns a list of URIs representing the user classpath. + * + * @param conf Spark configuration. */ - private def addUserClasspath( - args: ClientArguments, - conf: SparkConf, - env: HashMap[String, String]): Unit = { - - // If `args` is not null, we are launching an AM container. - // Otherwise, we are launching executor containers. - val (mainJar, secondaryJars) = - if (args != null) { - (args.userJar, args.addJars) - } else { - (conf.get(CONF_SPARK_USER_JAR, null), conf.get(CONF_SPARK_YARN_SECONDARY_JARS, null)) - } + def getUserClasspath(conf: SparkConf): Array[URI] = { + getUserClasspath(conf.getOption(CONF_SPARK_USER_JAR), + conf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) + } - addFileToClasspath(mainJar, APP_JAR, env) - if (secondaryJars != null) { - secondaryJars.split(",").filter(_.nonEmpty).foreach { jar => - addFileToClasspath(jar, null, env) - } - } + private def getUserClasspath( + mainJar: Option[String], + secondaryJars: Option[String]): Array[URI] = { + val mainUri = mainJar.orElse(Some(APP_JAR)).map(new URI(_)) + val secondaryUris = secondaryJars.map(_.split(",")).toSeq.flatten.map(new URI(_)) + (mainUri ++ secondaryUris).toArray } /** @@ -844,27 +839,19 @@ object Client extends Logging { * * If not a "local:" file and no alternate name, the environment is not modified. * - * @param path Path to add to classpath (optional). + * @param uri URI to add to classpath (optional). * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. */ private def addFileToClasspath( - path: String, + uri: URI, fileName: String, env: HashMap[String, String]): Unit = { - if (path != null) { - scala.util.control.Exception.ignoring(classOf[URISyntaxException]) { - val uri = new URI(path) - if (uri.getScheme == LOCAL_SCHEME) { - addClasspathEntry(uri.getPath, env) - return - } - } - } - if (fileName != null) { - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + fileName, env - ) + if (uri != null && uri.getScheme == LOCAL_SCHEME) { + addClasspathEntry(uri.getPath, env) + } else if (fileName != null) { + addClasspathEntry(buildPath( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) } } @@ -960,4 +947,24 @@ object Client extends Logging { new Path(qualifiedURI) } + /** + * Whether to consider jars provided by the user to have precedence over the Spark jars when + * loading user classes. + */ + def isUserClassPathFirst(conf: SparkConf, isDriver: Boolean): Boolean = { + if (isDriver) { + conf.getBoolean("spark.driver.userClassPathFirst", false) + } else { + conf.getBoolean("spark.executor.userClassPathFirst", + conf.getBoolean("spark.files.userClassPathFirst", false)) + } + } + + /** + * Joins all the path components using Path.SEPARATOR. + */ + def buildPath(components: String*): String = { + components.mkString(Path.SEPARATOR) + } + } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ee2002a35f523..c1d3f7320f53c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.yarn +import java.io.File import java.net.URI import java.nio.ByteBuffer @@ -37,7 +38,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.network.util.JavaUtils class ExecutorRunnable( @@ -56,8 +57,8 @@ class ExecutorRunnable( var rpc: YarnRPC = YarnRPC.create(conf) var nmClient: NMClient = _ val yarnConf: YarnConfiguration = new YarnConfiguration(conf) - lazy val env = prepareEnvironment - + lazy val env = prepareEnvironment(container) + def run = { logInfo("Starting Executor Container") nmClient = NMClient.createNMClient() @@ -108,7 +109,13 @@ class ExecutorRunnable( } // Send the start request to the ContainerManager - nmClient.startContainer(container, ctx) + try { + nmClient.startContainer(container, ctx) + } catch { + case ex: Exception => + throw new SparkException(s"Exception while starting container ${container.getId}" + + s" on host $hostname", ex) + } } private def prepareCommand( @@ -128,14 +135,15 @@ class ExecutorRunnable( // Set the JVM memory val executorMemoryString = executorMemory + "m" - javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " + javaOpts += "-Xms" + executorMemoryString + javaOpts += "-Xmx" + executorMemoryString // Set extra Java options for the executor, if defined sys.props.get("spark.executor.extraJavaOptions").foreach { opts => - javaOpts += opts + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.env.get("SPARK_JAVA_OPTS").foreach { opts => - javaOpts += opts + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.props.get("spark.executor.extraLibraryPath").foreach { p => prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) @@ -173,17 +181,27 @@ class ExecutorRunnable( // The options are based on // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use // %20the%20Concurrent%20Low%20Pause%20Collector|outline - javaOpts += " -XX:+UseConcMarkSweepGC " - javaOpts += " -XX:+CMSIncrementalMode " - javaOpts += " -XX:+CMSIncrementalPacing " - javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 " - javaOpts += " -XX:CMSIncrementalDutyCycle=10 " + javaOpts += "-XX:+UseConcMarkSweepGC" + javaOpts += "-XX:+CMSIncrementalMode" + javaOpts += "-XX:+CMSIncrementalPacing" + javaOpts += "-XX:CMSIncrementalDutyCycleMin=0" + javaOpts += "-XX:CMSIncrementalDutyCycle=10" } */ // For log4j configuration to reference javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) + val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri => + val absPath = + if (new File(uri.getPath()).isAbsolute()) { + uri.getPath() + } else { + Client.buildPath(Environment.PWD.$(), uri.getPath()) + } + Seq("--user-class-path", "file:" + absPath) + }.toSeq + val commands = prefixEnv ++ Seq( YarnSparkHadoopUtil.expandEnvironment(Environment.JAVA_HOME) + "/bin/java", "-server", @@ -195,11 +213,13 @@ class ExecutorRunnable( "-XX:OnOutOfMemoryError='kill %p'") ++ javaOpts ++ Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", - masterAddress.toString, - slaveId.toString, - hostname.toString, - executorCores.toString, - appId, + "--driver-url", masterAddress.toString, + "--executor-id", slaveId.toString, + "--hostname", hostname.toString, + "--cores", executorCores.toString, + "--app-id", appId) ++ + userClassPath ++ + Seq( "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") @@ -254,7 +274,7 @@ class ExecutorRunnable( localResources } - private def prepareEnvironment: HashMap[String, String] = { + private def prepareEnvironment(container: Container): HashMap[String, String] = { val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.executor.extraClassPath") Client.populateClasspath(null, yarnConf, sparkConf, env, extraCp) @@ -270,6 +290,14 @@ class ExecutorRunnable( YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) } + // Add log urls + sys.env.get("SPARK_USER").foreach { user => + val baseUrl = "http://%s/node/containerlogs/%s/%s" + .format(container.getNodeHttpAddress, ConverterUtils.toString(container.getId), user) + env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=0" + env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=0" + } + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } env } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 0dbb6154b3039..c98763e15b58f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -69,8 +69,7 @@ private[yarn] class YarnAllocator( } // Visible for testing. - val allocatedHostToContainersMap = - new HashMap[String, collection.mutable.Set[ContainerId]] + val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] val allocatedContainerToHostMap = new HashMap[ContainerId, String] // Containers that we no longer care about. We've either already told the RM to release them or @@ -84,10 +83,11 @@ private[yarn] class YarnAllocator( private var executorIdCounter = 0 @volatile private var numExecutorsFailed = 0 - @volatile private var maxExecutors = args.numExecutors + @volatile private var targetNumExecutors = args.numExecutors // Keep track of which container is running which executor to remove the executors later - private val executorIdToContainer = new HashMap[String, Container] + // Visible for testing. + private[yarn] val executorIdToContainer = new HashMap[String, Container] // Executor memory in MB. protected val executorMemory = args.executorMemory @@ -133,10 +133,15 @@ private[yarn] class YarnAllocator( amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).map(_.size).sum /** - * Request as many executors from the ResourceManager as needed to reach the desired total. + * Request as many executors from the ResourceManager as needed to reach the desired total. If + * the requested total is smaller than the current number of running executors, no executors will + * be killed. */ def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { - maxExecutors = requestedTotal + if (requestedTotal != targetNumExecutors) { + logInfo(s"Driver requested a total number of $requestedTotal executor(s).") + targetNumExecutors = requestedTotal + } } /** @@ -147,8 +152,6 @@ private[yarn] class YarnAllocator( val container = executorIdToContainer.remove(executorId).get internalReleaseContainer(container) numExecutorsRunning -= 1 - maxExecutors -= 1 - assert(maxExecutors >= 0, "Allocator killed more executors than are allocated!") } else { logWarning(s"Attempted to kill unknown executor $executorId!") } @@ -163,15 +166,8 @@ private[yarn] class YarnAllocator( * This must be synchronized because variables read in this method are mutated by other methods. */ def allocateResources(): Unit = synchronized { - val numPendingAllocate = getNumPendingAllocate - val missing = maxExecutors - numPendingAllocate - numExecutorsRunning + updateResourceRequests() - if (missing > 0) { - logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + - s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") - } - - addResourceRequests(missing) val progressIndicator = 0.1f // Poll the ResourceManager. This doubles as a heartbeat if there are no pending container // requests. @@ -201,15 +197,36 @@ private[yarn] class YarnAllocator( } /** - * Request numExecutors additional containers from YARN. Visible for testing. + * Update the set of container requests that we will sync with the RM based on the number of + * executors we have currently running and our target number of executors. + * + * Visible for testing. */ - def addResourceRequests(numExecutors: Int): Unit = { - for (i <- 0 until numExecutors) { - val request = new ContainerRequest(resource, null, null, RM_REQUEST_PRIORITY) - amClient.addContainerRequest(request) - val nodes = request.getNodes - val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last - logInfo("Container request (host: %s, capability: %s".format(hostStr, resource)) + def updateResourceRequests(): Unit = { + val numPendingAllocate = getNumPendingAllocate + val missing = targetNumExecutors - numPendingAllocate - numExecutorsRunning + + if (missing > 0) { + logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + + s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") + + for (i <- 0 until missing) { + val request = new ContainerRequest(resource, null, null, RM_REQUEST_PRIORITY) + amClient.addContainerRequest(request) + val nodes = request.getNodes + val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last + logInfo(s"Container request (host: $hostStr, capability: $resource)") + } + } else if (missing < 0) { + val numToCancel = math.min(numPendingAllocate, -missing) + logInfo(s"Canceling requests for $numToCancel executor containers") + + val matchingRequests = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, ANY_HOST, resource) + if (!matchingRequests.isEmpty) { + matchingRequests.head.take(numToCancel).foreach(amClient.removeContainerRequest) + } else { + logWarning("Expected to find pending requests, but found none.") + } } } @@ -266,7 +283,7 @@ private[yarn] class YarnAllocator( * containersToUse or remaining. * * @param allocatedContainer container that was given to us by YARN - * @location resource name, either a node, rack, or * + * @param location resource name, either a node, rack, or * * @param containersToUse list of containers that will be used * @param remaining list of containers that will not be used */ @@ -275,8 +292,14 @@ private[yarn] class YarnAllocator( location: String, containersToUse: ArrayBuffer[Container], remaining: ArrayBuffer[Container]): Unit = { + // SPARK-6050: certain Yarn configurations return a virtual core count that doesn't match the + // request; for example, capacity scheduler + DefaultResourceCalculator. So match on requested + // memory, but use the asked vcore count for matching, effectively disabling matching on vcore + // count. + val matchingResource = Resource.newInstance(allocatedContainer.getResource.getMemory, + resource.getVirtualCores) val matchingRequests = amClient.getMatchingRequests(allocatedContainer.getPriority, location, - allocatedContainer.getResource) + matchingResource) // Match the allocation to a request if (!matchingRequests.isEmpty) { @@ -294,7 +317,7 @@ private[yarn] class YarnAllocator( private def runAllocatedContainers(containersToUse: ArrayBuffer[Container]): Unit = { for (container <- containersToUse) { numExecutorsRunning += 1 - assert(numExecutorsRunning <= maxExecutors) + assert(numExecutorsRunning <= targetNumExecutors) val executorHostname = container.getNodeId.getHost val containerId = container.getId executorIdCounter += 1 @@ -303,7 +326,7 @@ private[yarn] class YarnAllocator( assert(container.getResource.getMemory >= resource.getMemory) logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) - executorIdToContainer(executorId) = container + executorIdToContainer(executorId) = container val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, new HashSet[ContainerId]) @@ -330,7 +353,8 @@ private[yarn] class YarnAllocator( } } - private def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { + // Visible for testing. + private[yarn] def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 690f927e938c3..8abdc26b43806 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} +import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments} @@ -78,18 +79,12 @@ private[spark] class YarnClientSchedulerBackend( ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( - "SPARK_MASTER_MEMORY" -> "SPARK_DRIVER_MEMORY or --driver-memory through spark-submit", "SPARK_WORKER_INSTANCES" -> "SPARK_WORKER_INSTANCES or --num-executors through spark-submit", "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit", "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit") - // Do the same for deprecated properties: property -> suggestion - val deprecatedProps = Map("spark.master.memory" -> "--driver-memory through spark-submit") optionTuples.foreach { case (optionName, envVar, sparkProp) => if (sc.getConf.contains(sparkProp)) { extraArgs += (optionName, sc.getConf.get(sparkProp)) - if (deprecatedProps.contains(sparkProp)) { - logWarning(s"NOTE: $sparkProp is deprecated. Use ${deprecatedProps(sparkProp)} instead.") - } } else if (System.getenv(envVar) != null) { extraArgs += (optionName, System.getenv(envVar)) if (deprecatedEnvVars.contains(envVar)) { @@ -133,8 +128,14 @@ private[spark] class YarnClientSchedulerBackend( val t = new Thread { override def run() { while (!stopping) { - val report = client.getApplicationReport(appId) - val state = report.getYarnApplicationState() + var state: YarnApplicationState = null + try { + val report = client.getApplicationReport(appId) + state = report.getYarnApplicationState() + } catch { + case e: ApplicationNotFoundException => + state = YarnApplicationState.KILLED + } if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.KILLED || state == YarnApplicationState.FAILED) { diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index 287c8e3563503..aab41fa49430f 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -16,7 +16,7 @@ # # Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file +log4j.rootCategory=DEBUG, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=target/unit-tests.log @@ -25,4 +25,4 @@ log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{ # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.apache.hadoop=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 2bb3dcffd61d9..92f04b4b859b3 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -28,8 +28,7 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.scalatest.FunSuite -import org.scalatest.Matchers +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} import scala.collection.JavaConversions._ import scala.collection.mutable.{ HashMap => MutableHashMap } @@ -39,7 +38,15 @@ import scala.util.Try import org.apache.spark.{SparkException, SparkConf} import org.apache.spark.util.Utils -class ClientSuite extends FunSuite with Matchers { +class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { + + override def beforeAll(): Unit = { + System.setProperty("SPARK_YARN_MODE", "true") + } + + override def afterAll(): Unit = { + System.clearProperty("SPARK_YARN_MODE") + } test("default Yarn application classpath") { Client.getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP)) @@ -82,6 +89,7 @@ class ClientSuite extends FunSuite with Matchers { test("Local jar URIs") { val conf = new Configuration() val sparkConf = new SparkConf().set(Client.CONF_SPARK_JAR, SPARK) + .set("spark.yarn.user.classpath.first", "true") val env = new MutableHashMap[String, String]() val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) @@ -98,13 +106,10 @@ class ClientSuite extends FunSuite with Matchers { }) if (classOf[Environment].getMethods().exists(_.getName == "$$")) { cp should contain("{{PWD}}") - cp should contain(s"{{PWD}}${Path.SEPARATOR}*") } else if (Utils.isWindows) { cp should contain("%PWD%") - cp should contain(s"%PWD%${Path.SEPARATOR}*") } else { cp should contain(Environment.PWD.$()) - cp should contain(s"${Environment.PWD.$()}${File.separator}*") } cp should not contain (Client.SPARK_JAR) cp should not contain (Client.APP_JAR) @@ -117,7 +122,7 @@ class ClientSuite extends FunSuite with Matchers { val client = spy(new Client(args, conf, sparkConf)) doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), - any(classOf[Path]), anyShort(), anyBoolean()) + any(classOf[Path]), anyShort()) val tempDir = Utils.createTempDir() try { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 024b25f9d3365..c09b01bafce37 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -107,8 +107,8 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach test("single container allocated") { // request a single container and receive it - val handler = createAllocator() - handler.addResourceRequests(1) + val handler = createAllocator(1) + handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (1) @@ -123,8 +123,8 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach test("some containers allocated") { // request a few containers and receive some of them - val handler = createAllocator() - handler.addResourceRequests(4) + val handler = createAllocator(4) + handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (4) @@ -144,7 +144,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach test("receive more containers than requested") { val handler = createAllocator(2) - handler.addResourceRequests(2) + handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (2) @@ -162,6 +162,72 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach handler.allocatedHostToContainersMap.contains("host4") should be (false) } + test("decrease total requested executors") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + handler.requestTotalExecutors(3) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (3) + + val container = createContainer("host1") + handler.handleAllocatedContainers(Array(container)) + + handler.getNumExecutorsRunning should be (1) + handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") + handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) + + handler.requestTotalExecutors(2) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (1) + } + + test("decrease total requested executors to less than currently running") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + handler.requestTotalExecutors(3) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (3) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + + handler.getNumExecutorsRunning should be (2) + + handler.requestTotalExecutors(1) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (0) + handler.getNumExecutorsRunning should be (2) + } + + test("kill executors") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + + handler.requestTotalExecutors(1) + handler.executorIdToContainer.keys.foreach { id => handler.killExecutor(id ) } + + val statuses = Seq(container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) + } + handler.updateResourceRequests() + handler.processCompletedContainers(statuses.toSeq) + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (1) + } + test("memory exceeded diagnostic regexes") { val diagnostics = "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 7165918e1bfcf..c06c0105670c0 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -17,26 +17,34 @@ package org.apache.spark.deploy.yarn -import java.io.File +import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.util.Properties import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ +import scala.collection.mutable -import com.google.common.base.Charsets +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.ByteStreams import com.google.common.io.Files -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} - import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils} +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded} import org.apache.spark.util.Utils +/** + * Integration tests for YARN; these tests use a mini Yarn cluster to run Spark-on-YARN + * applications, and require the Spark assembly to be built before they can be successfully + * run. + */ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging { - // log4j configuration for the Yarn containers, so that their output is collected - // by Yarn instead of trying to overwrite unit-tests.log. + // log4j configuration for the YARN containers, so that their output is collected + // by YARN instead of trying to overwrite unit-tests.log. private val LOG4J_CONF = """ |log4j.rootCategory=DEBUG, console |log4j.appender.console=org.apache.log4j.ConsoleAppender @@ -51,13 +59,11 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit | |from pyspark import SparkConf , SparkContext |if __name__ == "__main__": - | if len(sys.argv) != 3: - | print >> sys.stderr, "Usage: test.py [master] [result file]" + | if len(sys.argv) != 2: + | print >> sys.stderr, "Usage: test.py [result file]" | exit(-1) - | conf = SparkConf() - | conf.setMaster(sys.argv[1]).setAppName("python test in yarn cluster mode") - | sc = SparkContext(conf=conf) - | status = open(sys.argv[2],'w') + | sc = SparkContext(conf=SparkConf()) + | status = open(sys.argv[1],'w') | result = "failure" | rdd = sc.parallelize(range(10)) | cnt = rdd.count() @@ -71,21 +77,17 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit private var yarnCluster: MiniYARNCluster = _ private var tempDir: File = _ private var fakeSparkJar: File = _ - private var oldConf: Map[String, String] = _ + private var logConfDir: File = _ override def beforeAll() { - tempDir = Utils.createTempDir() + super.beforeAll() - val logConfDir = new File(tempDir, "log4j") + tempDir = Utils.createTempDir() + logConfDir = new File(tempDir, "log4j") logConfDir.mkdir() val logConfFile = new File(logConfDir, "log4j.properties") - Files.write(LOG4J_CONF, logConfFile, Charsets.UTF_8) - - val childClasspath = logConfDir.getAbsolutePath() + File.pathSeparator + - sys.props("java.class.path") - - oldConf = sys.props.filter { case (k, v) => k.startsWith("spark.") }.toMap + Files.write(LOG4J_CONF, logConfFile, UTF_8) yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) yarnCluster.init(new YarnConfiguration()) @@ -116,117 +118,249 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit } logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") - config.foreach { e => - sys.props += ("spark.hadoop." + e.getKey() -> e.getValue()) - } fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - sys.props += ("spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome) - sys.props += ("spark.executorEnv.SPARK_HOME" -> sparkHome) - sys.props += ("spark.yarn.jar" -> ("local:" + fakeSparkJar.getAbsolutePath())) - sys.props += ("spark.executor.instances" -> "1") - sys.props += ("spark.driver.extraClassPath" -> childClasspath) - sys.props += ("spark.executor.extraClassPath" -> childClasspath) - - super.beforeAll() } override def afterAll() { yarnCluster.stop() - sys.props.retain { case (k, v) => !k.startsWith("spark.") } - sys.props ++= oldConf super.afterAll() } test("run Spark in yarn-client mode") { - var result = File.createTempFile("result", null, tempDir) - YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath())) - checkResult(result) + testBasicYarnApp(true) } test("run Spark in yarn-cluster mode") { - val main = YarnClusterDriver.getClass.getName().stripSuffix("$") - var result = File.createTempFile("result", null, tempDir) - - val args = Array("--class", main, - "--jar", "file:" + fakeSparkJar.getAbsolutePath(), - "--arg", "yarn-cluster", - "--arg", result.getAbsolutePath(), - "--num-executors", "1") - Client.main(args) - checkResult(result) + testBasicYarnApp(false) } test("run Spark in yarn-cluster mode unsuccessfully") { - val main = YarnClusterDriver.getClass.getName().stripSuffix("$") - - // Use only one argument so the driver will fail - val args = Array("--class", main, - "--jar", "file:" + fakeSparkJar.getAbsolutePath(), - "--arg", "yarn-cluster", - "--num-executors", "1") + // Don't provide arguments so the driver will fail. val exception = intercept[SparkException] { - Client.main(args) + runSpark(false, mainClassName(YarnClusterDriver.getClass)) + fail("Spark application should have failed.") } - assert(Utils.exceptionString(exception).contains("Application finished with failed status")) } - test("run Python application in yarn-cluster mode") { + // Enable this once fix SPARK-6700 + ignore("run Python application in yarn-cluster mode") { val primaryPyFile = new File(tempDir, "test.py") - Files.write(TEST_PYFILE, primaryPyFile, Charsets.UTF_8) + Files.write(TEST_PYFILE, primaryPyFile, UTF_8) val pyFile = new File(tempDir, "test2.py") - Files.write(TEST_PYFILE, pyFile, Charsets.UTF_8) + Files.write(TEST_PYFILE, pyFile, UTF_8) var result = File.createTempFile("result", null, tempDir) - val args = Array("--class", "org.apache.spark.deploy.PythonRunner", - "--primary-py-file", primaryPyFile.getAbsolutePath(), - "--py-files", pyFile.getAbsolutePath(), - "--arg", "yarn-cluster", - "--arg", result.getAbsolutePath(), - "--name", "python test in yarn-cluster mode", - "--num-executors", "1") - Client.main(args) + // The sbt assembly does not include pyspark / py4j python dependencies, so we need to + // propagate SPARK_HOME so that those are added to PYTHONPATH. See PythonUtils.scala. + val sparkHome = sys.props("spark.test.home") + val extraConf = Map( + "spark.executorEnv.SPARK_HOME" -> sparkHome, + "spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome) + + runSpark(false, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files", pyFile.getAbsolutePath()), + appArgs = Seq(result.getAbsolutePath()), + extraConf = extraConf) checkResult(result) } + test("user class path first in client mode") { + testUseClassPathFirst(true) + } + + test("user class path first in cluster mode") { + testUseClassPathFirst(false) + } + + private def testBasicYarnApp(clientMode: Boolean): Unit = { + var result = File.createTempFile("result", null, tempDir) + runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), + appArgs = Seq(result.getAbsolutePath())) + checkResult(result) + } + + private def testUseClassPathFirst(clientMode: Boolean): Unit = { + // Create a jar file that contains a different version of "test.resource". + val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) + val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir) + val driverResult = File.createTempFile("driver", null, tempDir) + val executorResult = File.createTempFile("executor", null, tempDir) + runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), + appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()), + extraClassPath = Seq(originalJar.getPath()), + extraJars = Seq("local:" + userJar.getPath()), + extraConf = Map( + "spark.driver.userClassPathFirst" -> "true", + "spark.executor.userClassPathFirst" -> "true")) + checkResult(driverResult, "OVERRIDDEN") + checkResult(executorResult, "OVERRIDDEN") + } + + private def runSpark( + clientMode: Boolean, + klass: String, + appArgs: Seq[String] = Nil, + sparkArgs: Seq[String] = Nil, + extraClassPath: Seq[String] = Nil, + extraJars: Seq[String] = Nil, + extraConf: Map[String, String] = Map()): Unit = { + val master = if (clientMode) "yarn-client" else "yarn-cluster" + val props = new Properties() + + props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + + val childClasspath = logConfDir.getAbsolutePath() + + File.pathSeparator + + sys.props("java.class.path") + + File.pathSeparator + + extraClassPath.mkString(File.pathSeparator) + props.setProperty("spark.driver.extraClassPath", childClasspath) + props.setProperty("spark.executor.extraClassPath", childClasspath) + + // SPARK-4267: make sure java options are propagated correctly. + props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") + props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + + yarnCluster.getConfig().foreach { e => + props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) + } + + sys.props.foreach { case (k, v) => + if (k.startsWith("spark.")) { + props.setProperty(k, v) + } + } + + extraConf.foreach { case (k, v) => props.setProperty(k, v) } + + val propsFile = File.createTempFile("spark", ".properties", tempDir) + val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) + props.store(writer, "Spark properties.") + writer.close() + + val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil + val mainArgs = + if (klass.endsWith(".py")) { + Seq(klass) + } else { + Seq("--class", klass, fakeSparkJar.getAbsolutePath()) + } + val argv = + Seq( + new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), + "--master", master, + "--num-executors", "1", + "--properties-file", propsFile.getAbsolutePath()) ++ + extraJarArgs ++ + sparkArgs ++ + mainArgs ++ + appArgs + + Utils.executeAndGetOutput(argv, + extraEnvironment = Map("YARN_CONF_DIR" -> tempDir.getAbsolutePath())) + } + /** * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide * any sort of error when the job process finishes successfully, but the job itself fails. So * the tests enforce that something is written to a file after everything is ok to indicate * that the job succeeded. */ - private def checkResult(result: File) = { - var resultString = Files.toString(result, Charsets.UTF_8) - resultString should be ("success") + private def checkResult(result: File): Unit = { + checkResult(result, "success") + } + + private def checkResult(result: File, expected: String): Unit = { + var resultString = Files.toString(result, UTF_8) + resultString should be (expected) + } + + private def mainClassName(klass: Class[_]): String = { + klass.getName().stripSuffix("$") } } +private class SaveExecutorInfo extends SparkListener { + val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() + + override def onExecutorAdded(executor : SparkListenerExecutorAdded) { + addedExecutorInfos(executor.executorId) = executor.executorInfo + } +} + private object YarnClusterDriver extends Logging with Matchers { - def main(args: Array[String]) = { - if (args.length != 2) { + val WAIT_TIMEOUT_MILLIS = 10000 + var listener: SaveExecutorInfo = null + + def main(args: Array[String]): Unit = { + if (args.length != 1) { System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | - |Usage: YarnClusterDriver [master] [result file] + |Usage: YarnClusterDriver [result file] """.stripMargin) System.exit(1) } - val sc = new SparkContext(new SparkConf().setMaster(args(0)) + listener = new SaveExecutorInfo + val sc = new SparkContext(new SparkConf() .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns")) - val status = new File(args(1)) + sc.addSparkListener(listener) + val status = new File(args(0)) var result = "failure" try { val data = sc.parallelize(1 to 4, 4).collect().toSet + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) data should be (Set(1, 2, 3, 4)) result = "success" } finally { sc.stop() - Files.write(result, status, Charsets.UTF_8) + Files.write(result, status, UTF_8) + } + + // verify log urls are present + listener.addedExecutorInfos.values.foreach { info => + assert(info.logUrlMap.nonEmpty) + } + } + +} + +private object YarnClasspathTest { + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + System.err.println( + s""" + |Invalid command line: ${args.mkString(" ")} + | + |Usage: YarnClasspathTest [driver result file] [executor result file] + """.stripMargin) + System.exit(1) + } + + readResource(args(0)) + val sc = new SparkContext(new SparkConf()) + try { + sc.parallelize(Seq(1)).foreach { x => readResource(args(1)) } + } finally { + sc.stop() + } + } + + private def readResource(resultPath: String): Unit = { + var result = "failure" + try { + val ccl = Thread.currentThread().getContextClassLoader() + val resource = ccl.getResourceAsStream("test.resource") + val bytes = ByteStreams.toByteArray(resource) + result = new String(bytes, 0, bytes.length, UTF_8) + } finally { + Files.write(result, new File(resultPath), UTF_8) } }