diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 81fd4e782628a..4639da690662a 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -183,13 +183,44 @@ def __init__(self, jc):
 
     # container operators
     __contains__ = _bin_op("contains")
-    __getitem__ = _bin_op("apply")
 
     # bitwise operators
     bitwiseOR = _bin_op("bitwiseOR")
     bitwiseAND = _bin_op("bitwiseAND")
     bitwiseXOR = _bin_op("bitwiseXOR")
 
+    def __getitem__(self, key):
+        """
+        Return the item selected by ``key``, or a substring when ``key`` is a slice.
+
+        A slice is translated into a call to :func:`substr`, so ``col[i:j]`` is
+        ``col.substr(i, j - i)``, ``col[i:]`` is ``col.substr(i)``, ``col[:j]`` is
+        ``col.substr(1, j - 1)`` and ``col[:]`` is the column itself. Any other key
+        is passed to ``apply`` on the wrapped ``_jc`` object.
+
+        :param key: a slice, a :class:`Column`, or a plain value.
+        :raises ValueError: if ``key`` is a slice with a step other than 1.
+        """
+        if isinstance(key, slice):
+            if key.step is not None and key.step != 1:
+                raise ValueError("PySpark doesn't support slice with step")
+            if key.start is not None and key.stop is not None:
+                # str[i:j]
+                return self.substr(key.start, key.stop - key.start)
+            elif key.start is not None:
+                # str[i:]
+                return self.substr(key.start)
+            elif key.stop is not None:
+                # str[:j] -- the implicit start is the int 1, and substr rejects
+                # mixed int/Column arguments, so j has to be an int here.
+                return self.substr(1, key.stop - 1)
+            else:
+                # str[:]
+                return self
+        else:
+            jc = key._jc if isinstance(key, Column) else key
+            return Column(self._jc.apply(jc))
+
     @since(1.3)
     def getItem(self, key):
         """
@@ -250,7 +281,7 @@ def __iter__(self):
 
     @ignore_unicode_prefix
     @since(1.3)
-    def substr(self, startPos, length):
+    def substr(self, startPos, length=None):
         """
         Return a :class:`Column` which is a substring of the column.
 
@@ -259,10 +290,34 @@ def substr(self, startPos, length):
 
+        If ``length`` is omitted, the substring extends to the end of the string.
+
         >>> df.select(df.name.substr(1, 3).alias("col")).collect()
         [Row(col=u'Ali'), Row(col=u'Bob')]
+        >>> df.select(df.name[1:4].alias("col")).collect()
+        [Row(col=u'Ali'), Row(col=u'Bob')]
+        >>> df.select(df.name[2:].alias("col")).collect()
+        [Row(col=u'lice'), Row(col=u'ob')]
+        >>> df.select(df.name[:3].alias("col")).collect()
+        [Row(col=u'Al'), Row(col=u'Bo')]
+        >>> df.select(df.name[:].alias("col")).collect()
+        [Row(col=u'Alice'), Row(col=u'Bob')]
         """
-        if type(startPos) != type(length):
+        if length is not None and type(startPos) != type(length):
             raise TypeError("Can not mix the type")
         if isinstance(startPos, (int, long)):
+            javaMaxInt = SparkContext._active_spark_context._jvm.java.lang.Integer.MAX_VALUE
+            if startPos > javaMaxInt:
+                raise ValueError("startPos is larger than the java max int value "
+                                 "which is not supported by pyspark, startPos=" + str(startPos))
+            if length is None:
+                length = javaMaxInt
+            elif length > javaMaxInt:
+                raise ValueError("length is larger than the java max int value "
+                                 "which is not supported by pyspark, length=" + str(length))
             jc = self._jc.substr(startPos, length)
         elif isinstance(startPos, Column):
+            if length is None:
+                # No length given (e.g. ``col[start:]``): read to the end of the
+                # string, mirroring the int branch above.
+                jvm = SparkContext._active_spark_context._jvm
+                length = Column(jvm.functions.lit(jvm.java.lang.Integer.MAX_VALUE))
             jc = self._jc.substr(startPos._jc, length._jc)
@@ -270,8 +325,6 @@ def substr(self, startPos, length):
             raise TypeError("Unexpected type: %s" % type(startPos))
         return Column(jc)
 
-    __getslice__ = substr
-
     @ignore_unicode_prefix
     @since(1.3)
     def inSet(self, *cols):