diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 8d5adc8ffd6d..bd8043b54b3c 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -254,6 +254,7 @@ def substr(self, startPos, length): :param startPos: start position (int or Column) :param length: length of the substring (int or Column) + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.name.substr(1, 3).alias("col")).collect() [Row(col=u'Ali'), Row(col=u'Bob')] """ @@ -276,6 +277,7 @@ def isin(self, *cols): A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df[df.name.isin("Bob", "Mike")].collect() [Row(age=5, name=u'Bob')] >>> df[df.age.isin([1, 2, 3])].collect() @@ -303,6 +305,7 @@ def alias(self, *alias): Returns this column aliased with a new name or names (in the case of expressions that return more than one column, such as explode). + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.age.alias("age2")).collect() [Row(age2=2), Row(age2=5)] """ @@ -320,10 +323,13 @@ def alias(self, *alias): def cast(self, dataType): """ Convert the column into type ``dataType``. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.age.cast("string").alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] + [Row(ages='2'), Row(ages='5')] + + >>> from pyspark.sql.types import StringType >>> df.select(df.age.cast(StringType()).alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] + [Row(ages='2'), Row(ages='5')] """ if isinstance(dataType, basestring): jc = self._jc.cast(dataType) @@ -344,6 +350,7 @@ def between(self, lowerBound, upperBound): A boolean expression that is evaluated to true if the value of this expression is between the given columns. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.name, df.age.between(2, 4)).show() +-----+---------------------------+ | name|((age >= 2) AND (age <= 4))| @@ -366,6 +373,7 @@ def when(self, condition, value): :param value: a literal value, or a :class:`Column` expression. >>> from pyspark.sql import functions as F + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() +-----+------------------------------------------------------------+ | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END| @@ -391,6 +399,7 @@ def otherwise(self, value): :param value: a literal value, or a :class:`Column` expression. 
>>> from pyspark.sql import functions as F + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() +-----+-------------------------------------+ | name|CASE WHEN (age > 3) THEN 1 ELSE 0 END| @@ -412,9 +421,17 @@ def over(self, window): :return: a Column >>> from pyspark.sql import Window - >>> window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1) - >>> from pyspark.sql.functions import rank, min - >>> # df.select(rank().over(window), min('age').over(window)) + >>> window = Window.partitionBy("name").orderBy("age") + >>> from pyspark.sql.functions import rank + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob'), (3, 'Bob')], ['age', 'name']) + >>> df.select('name', 'age', rank().over(window)).show() + +-----+---+-----------------------------------------------------------------+ + | name|age|RANK() OVER (PARTITION BY name ORDER BY age ASC UnspecifiedFrame)| + +-----+---+-----------------------------------------------------------------+ + | Bob| 3| 1| + | Bob| 5| 2| + |Alice| 2| 1| + +-----+---+-----------------------------------------------------------------+ """ from pyspark.sql.window import WindowSpec if not isinstance(window, WindowSpec): @@ -442,10 +459,6 @@ def _test(): .getOrCreate() sc = spark.sparkContext globs['sc'] = sc - globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ - .toDF(StructType([StructField('age', IntegerType()), - StructField('name', StringType())])) - (failure_count, test_count) = doctest.testmod( pyspark.sql.column, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index de4c335ad275..b07942419b1e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -338,6 +338,7 @@ def registerDataFrameAsTable(self, df, tableName): Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. + >>> df = spark.createDataFrame([(1, 'row1'), (2, 'row2')], ['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") """ df.createOrReplaceTempView(tableName) @@ -346,6 +347,7 @@ def registerDataFrameAsTable(self, df, tableName): def dropTempTable(self, tableName): """ Remove the temp table from catalog. + >>> df = spark.createDataFrame([(1, 'row1'), (2, 'row2')], ['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> sqlContext.dropTempTable("table1") """ @@ -376,10 +378,11 @@ def sql(self, sqlQuery): :return: :class:`DataFrame` + >>> df = spark.createDataFrame([(1, 'row1'), (2, 'row2')], ['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() - [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + [Row(f1=1, f2='row1'), Row(f1=2, f2='row2')] """ return self.sparkSession.sql(sqlQuery) @@ -389,6 +392,7 @@ def table(self, tableName): :return: :class:`DataFrame` + >>> df = spark.createDataFrame([(1, 'row1'), (2, 'row2')], ['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -409,10 +413,11 @@ def tables(self, dbName=None): :param dbName: string, name of the database to use. 
:return: :class:`DataFrame` + >>> df = spark.createDataFrame([(1, 'row1'), (2, 'row2')], ['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() - Row(database=u'', tableName=u'table1', isTemporary=True) + Row(tableName='table1', isTemporary=True) """ if dbName is None: return DataFrame(self._ssql_ctx.tables(), self) @@ -426,6 +431,7 @@ def tableNames(self, dbName=None): :param dbName: string, name of the database to use. Default to the current database. :return: list of table names, in string + >>> df = spark.createDataFrame([(1, 'row1'), (2, 'row2')], ['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> "table1" in sqlContext.tableNames() True @@ -474,6 +480,7 @@ def readStream(self): :return: :class:`DataStreamReader` + >>> import tempfile >>> text_sdf = sqlContext.readStream.text(tempfile.mkdtemp()) >>> text_sdf.isStreaming True @@ -553,34 +560,16 @@ def register(self, name, f, returnType=StringType()): def _test(): import os import doctest - import tempfile from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.context - os.chdir(os.environ["SPARK_HOME"]) globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') - globs['tempfile'] = tempfile - globs['os'] = os + globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - globs['rdd'] = rdd = sc.parallelize( - [Row(field1=1, field2="row1"), - Row(field1=2, field2="row2"), - Row(field1=3, field2="row3")] - ) - globs['df'] = rdd.toDF() - jsonStrings = [ - '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' - ] - globs['jsonStrings'] = jsonStrings - globs['json'] = sc.parallelize(jsonStrings) (failure_count, test_count) = doctest.testmod( pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 29710acf54c4..2b5460c3f01d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -106,8 +106,9 @@ def toJSON(self, use_unicode=True): Each row is turned into a JSON document as one element in the returned RDD. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.toJSON().first() - u'{"age":2,"name":"Alice"}' + u'{"name":"Alice","age":2}' """ rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) @@ -119,6 +120,7 @@ def registerTempTable(self, name): The lifetime of this temporary table is tied to the :class:`SQLContext` that was used to create this :class:`DataFrame`. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.registerTempTable("people") >>> df2 = spark.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -138,6 +140,7 @@ def createTempView(self, name): throws :class:`TempTableAlreadyExistsException`, if the view name already exists in the catalog. 
+ >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.createTempView("people") >>> df2 = spark.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -158,6 +161,7 @@ def createOrReplaceTempView(self, name): The lifetime of this temporary table is tied to the :class:`SparkSession` that was used to create this :class:`DataFrame`. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.createOrReplaceTempView("people") >>> df2 = df.filter(df.age > 3) >>> df2.createOrReplaceTempView("people") @@ -219,6 +223,7 @@ def writeStream(self): def schema(self): """Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.schema StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) """ @@ -234,11 +239,12 @@ def schema(self): def printSchema(self): """Prints out the schema in the tree format. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.printSchema() root |-- age: integer (nullable = true) |-- name: string (nullable = true) - + """ print(self._jdf.schema().treeString()) @@ -248,6 +254,7 @@ def explain(self, extended=False): :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.explain() == Physical Plan == Scan ExistingRDD[age#0,name#1] @@ -297,6 +304,7 @@ def show(self, n=20, truncate=True): If set to a number greater than one, truncates long strings to length ``truncate`` and align cells right. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df DataFrame[age: int, name: string] >>> df.show() @@ -326,6 +334,7 @@ def __repr__(self): def count(self): """Returns the number of rows in this :class:`DataFrame`. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.count() 2 """ @@ -336,6 +345,7 @@ def count(self): def collect(self): """Returns all the records as a list of :class:`Row`. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ @@ -350,8 +360,9 @@ def toLocalIterator(self): Returns an iterator that contains all of the rows in this :class:`DataFrame`. The iterator will consume as much memory as the largest partition in this DataFrame. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> list(df.toLocalIterator()) - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] """ with SCCallSiteSync(self._sc) as css: port = self._jdf.toPythonIterator() @@ -362,6 +373,7 @@ def toLocalIterator(self): def limit(self, num): """Limits the result count to the number specified. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.limit(1).collect() [Row(age=2, name=u'Alice')] >>> df.limit(0).collect() @@ -375,8 +387,9 @@ def limit(self, num): def take(self, num): """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. 
+ >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.take(2) - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] """ return self.limit(num).collect() @@ -489,6 +502,7 @@ def repartition(self, numPartitions, *cols): Added optional arguments to specify the partitioning columns. Also made numPartitions optional if partitioning columns are specified. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.repartition(10).rdd.getNumPartitions() 10 >>> data = df.union(df).repartition("age") @@ -540,6 +554,7 @@ def repartition(self, numPartitions, *cols): def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.distinct().count() 2 """ @@ -549,6 +564,7 @@ def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.sample(False, 0.5, 42).count() 2 """ @@ -601,6 +617,11 @@ def randomSplit(self, weights, seed=None): be normalized if they don't sum up to 1.0. :param seed: The seed for sampling. + >>> from pyspark.sql import Row + >>> df4 = spark.createDataFrame([Row(name='Alice', age=10, height=80), + ... Row(name='Bob', age=5, height=None), + ... Row(name='Tom', age=None, height=None), + ... Row(name=None, age=None, height=None)]) >>> splits = df4.randomSplit([1.0, 2.0], 24) >>> splits[0].count() 1 @@ -620,6 +641,7 @@ def randomSplit(self, weights, seed=None): def dtypes(self): """Returns all column names and their data types as a list. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.dtypes [('age', 'int'), ('name', 'string')] """ @@ -630,6 +652,7 @@ def columns(self): """Returns all column names as a list. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.columns ['age', 'name'] """ @@ -641,6 +664,7 @@ def alias(self, alias): """Returns a new :class:`DataFrame` with an alias set. >>> from pyspark.sql.functions import * + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') @@ -684,6 +708,8 @@ def join(self, other, on=None, how=None): The following performs a full outer join between ``df1`` and ``df2``. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) + >>> df2 = spark.createDataFrame([('Tom', 80), ('Bob', 85)], ['name', 'height']) >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] @@ -1435,8 +1461,9 @@ def withColumn(self, colName, col): :param colName: string, name of the new column. :param col: a :class:`Column` expression for the new column. 
+ >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.withColumn('age2', df.age + 2).collect() - [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + [Row(name=u'Alice', age=2, age2=4), Row(name=u'Bob', age=5, age2=7)] """ assert isinstance(col, Column), "col should be Column" return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) @@ -1450,8 +1477,9 @@ def withColumnRenamed(self, existing, new): :param existing: string, name of the existing column to rename. :param col: string, new name of the column. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.withColumnRenamed('age', 'age2').collect() - [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] + [Row(name=u'Alice', age2=2), Row(name=u'Bob', age2=5)] """ return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx) @@ -1464,6 +1492,8 @@ def drop(self, *cols): :param cols: a string name of the column to drop, or a :class:`Column` to drop, or a list of string name of the columns to drop. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) + >>> df2 = spark.createDataFrame([('Tom', 80), ('Bob', 85)], ['name', 'height']) >>> df.drop('age').collect() [Row(name=u'Alice'), Row(name=u'Bob')] @@ -1501,8 +1531,9 @@ def toDF(self, *cols): :param cols: list of new column names (string) + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.toDF('f1', 'f2').collect() - [Row(f1=2, f2=u'Alice'), Row(f1=5, f2=u'Bob')] + [Row(f1=u'Alice', f2=2), Row(f1=u'Bob', f2=5)] """ jdf = self._jdf.toDF(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) @@ -1516,6 +1547,7 @@ def toPandas(self): This is only available if Pandas is installed and available. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice @@ -1626,16 +1658,6 @@ def _test(): globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) globs['spark'] = SparkSession(sc) - globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ - .toDF(StructType([StructField('age', IntegerType()), - StructField('name', StringType())])) - globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() - globs['df3'] = sc.parallelize([Row(name='Alice', age=2), - Row(name='Bob', age=5)]).toDF() - globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), - Row(name='Bob', age=5, height=None), - Row(name='Tom', age=None, height=None), - Row(name=None, age=None, height=None)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.dataframe, globs=globs, diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7fa3fd2de7dd..f991c4a8406f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -204,6 +204,7 @@ def approxCountDistinct(col, rsd=None): def approx_count_distinct(col, rsd=None): """Returns a new :class:`Column` for approximate distinct count of ``col``. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.agg(approx_count_distinct(df.age).alias('c')).collect() [Row(c=2)] """ @@ -309,6 +310,7 @@ def covar_samp(col1, col2): def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. 
+ >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect() [Row(c=2)] @@ -338,6 +340,7 @@ def grouping(col): Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated or not, returns 1 for aggregated or 0 for not aggregated in the result set. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show() +-----+--------------+--------+ | name|grouping(name)|sum(age)| @@ -362,6 +365,7 @@ def grouping_id(*cols): Note: the list of columns should match with grouping columns exactly, or empty (means all the grouping columns). + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show() +-----+-------------+--------+ | name|grouping_id()|sum(age)| @@ -426,7 +430,7 @@ def monotonically_increasing_id(): The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. The current implementation puts the partition ID in the upper 31 bits, and the record number - within each partition in the lower 33 bits. The assumption is that the data frame has + within each partition in the lower 33 bits. The assumption is that the :class:`DataFrame` has less than 1 billion partitions, and each partition has less than 8 billion records. As an example, consider a :class:`DataFrame` with two partitions, each with 3 records. @@ -558,6 +562,7 @@ def spark_partition_id(): def expr(str): """Parses the expression string into the column that it represents + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(expr("length(name)")).collect() [Row(length(name)=5), Row(length(name)=3)] """ @@ -572,6 +577,7 @@ def struct(*cols): :param cols: list of column names (string) or list of :class:`Column` expressions + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(struct('age', 'name').alias("struct")).collect() [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] >>> df.select(struct([df.age, df.name]).alias("struct")).collect() @@ -618,12 +624,14 @@ def least(*cols): @since(1.4) def when(condition, value): - """Evaluates a list of conditions and returns one of multiple possible result expressions. + """ + Evaluates a list of conditions and returns one of multiple possible result expressions. If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. :param condition: a boolean :class:`Column` expression. :param value: a literal value, or a :class:`Column` expression. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect() [Row(age=3), Row(age=4)] @@ -644,6 +652,7 @@ def log(arg1, arg2=None): If there is only one argument, then this takes the natural logarithm of the argument. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(log(10.0, df.age).alias('ten')).rdd.map(lambda l: str(l.ten)[:7]).collect() ['0.30102', '0.69897'] @@ -777,7 +786,7 @@ def date_format(date, format): A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All pattern letters of the Java class `java.text.SimpleDateFormat` can be used. - NOTE: Use when ever possible specialized functions like `year`. 
These benefit from a + NOTE: Whenever possible, use specialized functions like `year`. These benefit from a specialized implementation. >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) @@ -1137,7 +1146,7 @@ def check_string_field(field, fieldName): @ignore_unicode_prefix def crc32(col): """ - Calculates the cyclic redundancy check value (CRC32) of a binary column and + Calculates the cyclic redundancy check value (CRC32) of a binary column and returns the value as a bigint. >>> spark.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() @@ -1180,6 +1189,7 @@ def sha2(col, numBits): and SHA-512). The numBits indicates the desired bit length of the result, which must have a value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() >>> digests[0] Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') @@ -1518,6 +1528,7 @@ def soundex(col): def bin(col): """Returns the string representation of the binary value of the given column. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(bin(df.age).alias('c')).collect() [Row(c=u'10'), Row(c=u'101')] """ @@ -1592,6 +1603,7 @@ def create_map(*cols): :param cols: list of column names (string) or list of :class:`Column` expressions that grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...). + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(create_map('name', 'age').alias("map")).collect() [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})] >>> df.select(create_map([df.name, df.age]).alias("map")).collect() @@ -1611,6 +1623,7 @@ def array(*cols): :param cols: list of column names (string) or list of :class:`Column` expressions that have the same data type. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(array('age', 'age').alias("arr")).collect() [Row(arr=[2, 2]), Row(arr=[5, 5])] >>> df.select(array([df.age, df.age]).alias("arr")).collect() @@ -1833,6 +1846,7 @@ def udf(f, returnType=StringType()): >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(slen(df.name).alias('slen')).collect() [Row(slen=5), Row(slen=3)] """ @@ -1856,7 +1870,6 @@ def _test(): sc = spark.sparkContext globs['sc'] = sc globs['spark'] = spark - globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index f2092f9c6305..ec75f2b72d4f 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -73,13 +73,14 @@ def agg(self, *exprs): :param exprs: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. 
+ >>> df = spark.createDataFrame([(4, 'Bob'), (2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> gdf = df.groupBy(df.name) >>> sorted(gdf.agg({"*": "count"}).collect()) - [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] + [Row(name='Alice', count(1)=1), Row(name='Bob', count(1)=2)] >>> from pyspark.sql import functions as F >>> sorted(gdf.agg(F.min(df.age)).collect()) - [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] + [Row(name='Alice', min(age)=2), Row(name='Bob', min(age)=4)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -96,8 +97,12 @@ def agg(self, *exprs): def count(self): """Counts the number of records for each group. - >>> sorted(df.groupBy(df.age).count().collect()) - [Row(age=2, count=1), Row(age=5, count=1)] + >>> df = spark.createDataFrame([(7, 'Bob'), (7, 'Alice'), (5, 'Bob')], ['age', 'name']) + >>> df.groupBy(df.age).count().sort(df.age).collect() + [Row(age=5, count=1), Row(age=7, count=2)] + + >>> df.groupBy(df.age).count().sort(df.age, ascending=False).collect() + [Row(age=7, count=2), Row(age=5, count=1)] """ @df_varargs_api @@ -109,8 +114,12 @@ def mean(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.groupBy().mean('age').collect() [Row(avg(age)=3.5)] + + >>> df3 = spark.createDataFrame([(2, 'Alice', 80), (5, 'Bob', 85)], + ... ['age', 'name', 'height']) >>> df3.groupBy().mean('age', 'height').collect() [Row(avg(age)=3.5, avg(height)=82.5)] """ @@ -124,8 +133,12 @@ def avg(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.groupBy().avg('age').collect() [Row(avg(age)=3.5)] + + >>> df3 = spark.createDataFrame([(2, 'Alice', 80), (5, 'Bob', 85)], + ... ['age', 'name', 'height']) >>> df3.groupBy().avg('age', 'height').collect() [Row(avg(age)=3.5, avg(height)=82.5)] """ @@ -135,8 +148,12 @@ def max(self, *cols): """Computes the max value for each numeric columns for each group. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.groupBy().max('age').collect() [Row(max(age)=5)] + + >>> df3 = spark.createDataFrame([(2, 'Alice', 80), (5, 'Bob', 85)], + ... ['age', 'name', 'height']) >>> df3.groupBy().max('age', 'height').collect() [Row(max(age)=5, max(height)=85)] """ @@ -148,8 +165,12 @@ def min(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.groupBy().min('age').collect() [Row(min(age)=2)] + + >>> df3 = spark.createDataFrame([(2, 'Alice', 80), (5, 'Bob', 85)], + ... ['age', 'name', 'height']) >>> df3.groupBy().min('age', 'height').collect() [Row(min(age)=2, min(height)=80)] """ @@ -161,8 +182,12 @@ def sum(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. 
+ >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.groupBy().sum('age').collect() [Row(sum(age)=7)] + + >>> df3 = spark.createDataFrame([(2, 'Alice', 80), (5, 'Bob', 85)], + ... ['age', 'name', 'height']) >>> df3.groupBy().sum('age', 'height').collect() [Row(sum(age)=7, sum(height)=165)] """ @@ -180,6 +205,12 @@ def pivot(self, pivot_col, values=None): # Compute the sum of earnings for each year by course with each course as a separate column + >>> df4 = spark.createDataFrame([("dotNET", 10000, 2012), + ... ("Java", 20000, 2012), + ... ("dotNET", 5000, 2012), + ... ("dotNET", 48000, 2013), + ... ("Java", 30000, 2013)], + ... ['course', 'earnings', 'year']) >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] @@ -206,17 +237,6 @@ def _test(): .getOrCreate() sc = spark.sparkContext globs['sc'] = sc - globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ - .toDF(StructType([StructField('age', IntegerType()), - StructField('name', StringType())])) - globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), - Row(name='Bob', age=5, height=85)]).toDF() - globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000), - Row(course="Java", year=2012, earnings=20000), - Row(course="dotNET", year=2012, earnings=5000), - Row(course="dotNET", year=2013, earnings=48000), - Row(course="Java", year=2013, earnings=30000)]).toDF() - (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 91c2b17049fa..f358791157ab 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -467,6 +467,8 @@ def mode(self, saveMode): * `error`: Throw an exception if data already exists. * `ignore`: Silently ignore this operation if data already exists. + >>> import os, tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ # At the JVM side, the default value of mode is already set to "error". @@ -481,6 +483,8 @@ def format(self, source): :param source: string, name of the data source, e.g. 'json', 'parquet'. + >>> import os, tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data')) """ self._jwrite = self._jwrite.format(source) @@ -510,6 +514,8 @@ def partitionBy(self, *cols): :param cols: name of columns + >>> import os, tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ if len(cols) == 1 and isinstance(cols[0], (list, tuple)): @@ -536,6 +542,8 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options): :param partitionBy: names of partitioning columns :param options: all other string options + >>> import os, tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode).options(**options) @@ -609,6 +617,8 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm This applies to timestamp type. 
If None is set, it uses the default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + >>> import os, tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) @@ -634,6 +644,8 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): is set, it uses the value specified in ``spark.sql.parquet.compression.codec``. + >>> import os, tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) @@ -701,6 +713,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No This applies to timestamp type. If None is set, it uses the default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + >>> import os, tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) @@ -728,6 +742,7 @@ def orc(self, path, mode=None, partitionBy=None, compression=None): This will override ``orc.compress``. If None is set, it uses the default value, ``snappy``. + >>> import os, tempfile >>> orc_df = spark.read.orc('python/test_support/sql/orc_partitioned') >>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -767,7 +782,6 @@ def jdbc(self, url, table, mode=None, properties=None): def _test(): import doctest import os - import tempfile import py4j from pyspark.context import SparkContext from pyspark.sql import SparkSession, Row @@ -782,11 +796,8 @@ def _test(): except py4j.protocol.Py4JError: spark = SparkSession(sc) - globs['tempfile'] = tempfile - globs['os'] = os globs['sc'] = sc globs['spark'] = spark - globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned') (failure_count, test_count) = doctest.testmod( pyspark.sql.readwriter, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 1e40b9c39fc4..5f8abc930b52 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -51,8 +51,15 @@ def toDF(self, schema=None, sampleRatio=None): :param samplingRatio: the sample ratio of rows used for inferring :return: a DataFrame - >>> rdd.toDF().collect() - [Row(name=u'Alice', age=1)] + >>> from pyspark.sql import Row + >>> rdd = sc.parallelize([Row(field1=1, field2="row1"), Row(field1=2, field2="row2")]) + >>> rdd.toDF().show() + +------+------+ + |field1|field2| + +------+------+ + | 1| row1| + | 2| row2| + +------+------+ """ return sparkSession.createDataFrame(self, schema, sampleRatio) @@ -533,10 +540,11 @@ def sql(self, sqlQuery): :return: :class:`DataFrame` + >>> df = spark.createDataFrame([(1, 'row1'), (2, 'row2')], ['field1', 'field2']) >>> df.createOrReplaceTempView("table1") >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() - [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2')] """ return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped) @@ -546,6 +554,7 @@ def table(self, tableName): :return: :class:`DataFrame` + >>> df = spark.createDataFrame([(1, 'row1'), (2, 'row2')], ['field1', 'field2']) >>> df.createOrReplaceTempView("table1") >>> df2 = spark.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -627,11 +636,6 @@ def _test(): sc = SparkContext('local[4]', 
'PythonTest') globs['sc'] = sc globs['spark'] = SparkSession(sc) - globs['rdd'] = rdd = sc.parallelize( - [Row(field1=1, field2="row1"), - Row(field1=2, field2="row2"), - Row(field1=3, field2="row3")]) - globs['df'] = rdd.toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.session, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 35fc46929168..6535646ede1a 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -121,6 +121,7 @@ def __init__(self, jsqm): def active(self): """Returns a list of active queries associated with this SQLContext + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> sq = sdf.writeStream.format('memory').queryName('this_query').start() >>> sqm = spark.streams >>> # get the list of active streaming queries @@ -136,6 +137,7 @@ def get(self, id): """Returns an active query from this SQLContext or throws exception if an active query with this name doesn't exist. + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> sq = sdf.writeStream.format('memory').queryName('this_query').start() >>> sq.name u'this_query' @@ -568,6 +570,8 @@ def schema(self, schema): :param schema: a :class:`pyspark.sql.types.StructType` object + >>> from pyspark.sql.types import StructType, StructField, StringType + >>> sdf_schema = StructType([StructField("data", StringType(), False)]) >>> s = spark.readStream.schema(sdf_schema) """ from pyspark.sql import SparkSession @@ -612,9 +616,12 @@ def load(self, path=None, format=None, schema=None, **options): :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. :param options: all other string options - >>> json_sdf = spark.readStream.format("json") \\ - ... .schema(sdf_schema) \\ - ... .load(tempfile.mkdtemp()) + >>> import tempfile + >>> from pyspark.sql.types import StructType, StructField, StringType + >>> sdf_schema = StructType([StructField("data", StringType(), False)]) + >>> json_sdf = (spark.readStream.format("json") + ... .schema(sdf_schema) + ... .load(tempfile.mkdtemp())) >>> json_sdf.isStreaming True >>> json_sdf.schema == sdf_schema @@ -690,7 +697,10 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, This applies to timestamp type. If None is set, it uses the default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) + >>> import tempfile + >>> from pyspark.sql.types import StructType, StructField, StringType + >>> sdf_schema = StructType([StructField("data", StringType(), False)]) + >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema=sdf_schema) >>> json_sdf.isStreaming True >>> json_sdf.schema == sdf_schema @@ -719,6 +729,9 @@ def parquet(self, path): .. note:: Experimental. + >>> import tempfile + >>> from pyspark.sql.types import StructType, StructField, StringType + >>> sdf_schema = StructType([StructField("data", StringType(), False)]) >>> parquet_sdf = spark.readStream.schema(sdf_schema).parquet(tempfile.mkdtemp()) >>> parquet_sdf.isStreaming True @@ -823,6 +836,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. 
+ >>> import tempfile + >>> from pyspark.sql.types import StructType, StructField, StringType + >>> sdf_schema = StructType([StructField("data", StringType(), False)]) >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming True @@ -876,6 +892,7 @@ def outputMode(self, outputMode): .. note:: Experimental. + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> writer = sdf.writeStream.outputMode('append') """ if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0: @@ -891,6 +908,7 @@ def format(self, source): :param source: string, name of the data source, which for now can be 'parquet'. + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> writer = sdf.writeStream.format('json') """ self._jwrite = self._jwrite.format(source) @@ -942,6 +960,7 @@ def queryName(self, queryName): :param queryName: unique name for the query + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> writer = sdf.writeStream.queryName('streaming_query') """ if not queryName or type(queryName) != str or len(queryName.strip()) == 0: @@ -959,6 +978,7 @@ def trigger(self, processingTime=None): :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') """ @@ -997,6 +1017,7 @@ def start(self, path=None, format=None, partitionBy=None, queryName=None, **opti :param options: All other string options. You may want to provide a `checkpointLocation` for most streams, however it is not required for a `memory` stream. + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> sq = sdf.writeStream.format('memory').queryName('this_query').start() >>> sq.isActive True @@ -1041,15 +1062,8 @@ def _test(): except py4j.protocol.Py4JError: spark = SparkSession(sc) - globs['tempfile'] = tempfile - globs['os'] = os globs['spark'] = spark globs['sqlContext'] = SQLContext.getOrCreate(spark.sparkContext) - globs['sdf'] = \ - spark.readStream.format('text').load('python/test_support/sql/streaming') - globs['sdf_schema'] = StructType([StructField("data", StringType(), False)]) - globs['df'] = \ - globs['spark'].readStream.format('text').load('python/test_support/sql/streaming') globs['sqs'] = StreamingQueryStatus( spark.sparkContext._jvm.org.apache.spark.sql.streaming.StreamingQueryStatus.testStatus())
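
The common thread in this diff is that each doctest now builds the DataFrame it needs with spark.createDataFrame(...) inline, so the per-module _test() harnesses only have to supply the shared spark and sc globals instead of injecting df/df2/df3/df4 fixtures. A minimal sketch of how such a harness runs one module's doctests is shown below; it assumes a local master and uses pyspark.sql.group purely as an example, so the builder options and the helper name run_module_doctests are illustrative, not part of this diff.

import doctest
import sys

from pyspark.sql import SparkSession
import pyspark.sql.group  # any module touched by this diff works the same way


def run_module_doctests():
    # Only 'spark' and 'sc' go into globs now; every doctest creates the
    # DataFrame it needs inline, so no per-module df/df2/df3/df4 fixtures.
    spark = (SparkSession.builder
             .master("local[4]")
             .appName("sql.group doctests")
             .getOrCreate())
    globs = pyspark.sql.group.__dict__.copy()
    globs["spark"] = spark
    globs["sc"] = spark.sparkContext
    failure_count, test_count = doctest.testmod(
        pyspark.sql.group, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    spark.stop()
    return failure_count


if __name__ == "__main__":
    sys.exit(run_module_doctests())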