Commit 0a1a73b

Merge branch 'df1' of github.com:rxin/spark into df1

2 parents: 23b4427 + 828f70d

3 files changed: +28 −29 lines changed

python/pyspark/sql.py

Lines changed: 24 additions & 15 deletions
@@ -1973,7 +1973,7 @@ def collect(self):
         [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
         """
         with SCCallSiteSync(self._sc) as css:
-            bytesInJava = self._jdf.collectToPython().iterator()
+            bytesInJava = self._jdf.javaToPython().collect().iterator()
         cls = _create_cls(self.schema())
         tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
         tempFile.close()
@@ -1997,14 +1997,14 @@ def take(self, num):
         return self.limit(num).collect()

     def map(self, f):
+        """ Return a new RDD by applying a function to each Row, it's a
+        shorthand for df.rdd.map()
+        """
         return self.rdd.map(f)

-    # Convert each object in the RDD to a Row with the right class
-    # for this DataFrame, so that fields can be accessed as attributes.
     def mapPartitions(self, f, preservesPartitioning=False):
         """
-        Return a new RDD by applying a function to each partition of this RDD,
-        while tracking the index of the original partition.
+        Return a new RDD by applying a function to each partition.

         >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
         >>> def f(iterator): yield 1
@@ -2013,21 +2013,28 @@ def mapPartitions(self, f, preservesPartitioning=False):
         """
         return self.rdd.mapPartitions(f, preservesPartitioning)

-    # We override the default cache/persist/checkpoint behavior
-    # as we want to cache the underlying DataFrame object in the JVM,
-    # not the PythonRDD checkpointed by the super class
     def cache(self):
+        """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
+        """
         self.is_cached = True
         self._jdf.cache()
         return self

     def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
+        """ Set the storage level to persist its values across operations
+        after the first time it is computed. This can only be used to assign
+        a new storage level if the RDD does not have a storage level set yet.
+        If no storage level is specified defaults to (C{MEMORY_ONLY_SER}).
+        """
         self.is_cached = True
         javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
         self._jdf.persist(javaStorageLevel)
         return self

     def unpersist(self, blocking=True):
+        """ Mark it as non-persistent, and remove all blocks for it from
+        memory and disk.
+        """
         self.is_cached = False
         self._jdf.unpersist(blocking)
         return self
@@ -2036,10 +2043,12 @@ def unpersist(self, blocking=True):
     #     rdd = self._jdf.coalesce(numPartitions, shuffle, None)
     #     return DataFrame(rdd, self.sql_ctx)

-    # def repartition(self, numPartitions):
-    #     rdd = self._jdf.repartition(numPartitions, None)
-    #     return DataFrame(rdd, self.sql_ctx)
-    #
+    def repartition(self, numPartitions):
+        """ Return a new :class:`DataFrame` that has exactly `numPartitions`
+        partitions.
+        """
+        rdd = self._jdf.repartition(numPartitions, None)
+        return DataFrame(rdd, self.sql_ctx)

     def sample(self, withReplacement, fraction, seed=None):
         """
@@ -2359,11 +2368,11 @@ def _scalaMethod(name):
     """ Translate operators into methodName in Scala

     For example:
-    >>> scalaMethod('+')
+    >>> _scalaMethod('+')
     '$plus'
-    >>> scalaMethod('>=')
+    >>> _scalaMethod('>=')
     '$greater$eq'
-    >>> scalaMethod('cast')
+    >>> _scalaMethod('cast')
     'cast'
     """
     return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
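
As a quick illustration of the methods this file now documents or restores (map, mapPartitions, cache, persist, unpersist, repartition), here is a minimal PySpark sketch. It assumes a live SparkContext `sc` and SQLContext `sqlCtx` (the same fixture names the pyspark doctests use); the sample data and variable names are illustrative and not part of the commit.

    from pyspark.sql import Row

    # Assumed fixtures: `sc` (SparkContext) and `sqlCtx` (SQLContext).
    rdd = sc.parallelize([Row(key=i, value=str(i)) for i in range(100)])
    df = sqlCtx.inferSchema(rdd)            # createDataFrame on later revisions

    df.cache()                              # caches the underlying JVM DataFrame
    keys = df.map(lambda row: row.key)      # shorthand for df.rdd.map(...)
    sizes = df.mapPartitions(lambda it: [sum(1 for _ in it)])
    df4 = df.repartition(4)                 # new DataFrame with exactly 4 partitions
    df.unpersist()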

python/pyspark/tests.py

Lines changed: 3 additions & 3 deletions
@@ -946,8 +946,7 @@ def test_apply_schema_with_udt(self):
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", ExamplePointUDT(), False)])
         df = self.sqlCtx.applySchema(rdd, schema)
-        # TODO: test collect with UDT
-        point = df.rdd.first().point
+        point = df.head().point
         self.assertEquals(point, ExamplePoint(1.0, 2.0))

     def test_parquet_with_udt(self):
@@ -984,11 +983,12 @@ def test_column_select(self):
         self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())

     def test_aggregator(self):
-        from pyspark.sql import Aggregator as Agg
         df = self.df
         g = df.groupBy()
         self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
         self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+        # TODO(davies): fix aggregators
+        from pyspark.sql import Aggregator as Agg
         # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 1 addition & 11 deletions
@@ -590,17 +590,7 @@ class DataFrame protected[sql](
    */
   protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
     val fieldTypes = schema.fields.map(_.dataType)
-    val jrdd = this.rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
+    val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
     SerDeUtil.javaToPython(jrdd)
   }
-  /**
-   * Serializes the Array[Row] returned by collect(), using the same format as javaToPython.
-   */
-  protected[sql] def collectToPython: JList[Array[Byte]] = {
-    val fieldTypes = schema.fields.map(_.dataType)
-    val pickle = new Pickler
-    new ArrayList[Array[Byte]](collect().map { row =>
-      EvaluatePython.rowToArray(row, fieldTypes)
-    }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
-  }
 }
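
For context, the deleted collectToPython pickled the collected rows on the JVM side in batches of 100, in the same format javaToPython produces; after this change the Python collect() path (first hunk of python/pyspark/sql.py above) simply calls javaToPython().collect(). Below is a standalone sketch of that batch-of-100 pickle framing, with plain tuples standing in for rows; the helper names are invented for illustration.

    import pickle

    # Plain tuples stand in for Spark rows; batch size 100 mirrors the
    # grouped(100) call in the deleted Scala method.
    rows = [(i, u"row%d" % i) for i in range(250)]

    def to_pickled_batches(rows, batch_size=100):
        for start in range(0, len(rows), batch_size):
            yield pickle.dumps(rows[start:start + batch_size])

    def from_pickled_batches(batches):
        for blob in batches:
            for row in pickle.loads(blob):
                yield row

    batches = list(to_pickled_batches(rows))
    assert list(from_pickled_batches(batches)) == rows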
