diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx
index a33744fc7d04..15c4ac155343 100644
--- a/python/ray/_raylet.pyx
+++ b/python/ray/_raylet.pyx
@@ -1720,6 +1720,12 @@ cdef class CoreWorker:
             CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
                 c_object_id)
 
+    def get_owner_address(self, ObjectRef object_ref):
+        cdef:
+            CObjectID c_object_id = object_ref.native()
+        return CCoreWorkerProcess.GetCoreWorker().GetOwnerAddress(
+            c_object_id).SerializeAsString()
+
     def serialize_and_promote_object_ref(self, ObjectRef object_ref):
         cdef:
             CObjectID c_object_id = object_ref.native()
diff --git a/python/ray/data/block.py b/python/ray/data/block.py
index 41ccd545deaa..a9b570115c8e 100644
--- a/python/ray/data/block.py
+++ b/python/ray/data/block.py
@@ -17,7 +17,7 @@
 #
 # Block data can be accessed in a uniform way via ``BlockAccessors`` such as
 # ``SimpleBlockAccessor``, ``ArrowBlockAccessor``, and ``TensorBlockAccessor``.
-Block = Union[List[T], np.ndarray, "pyarrow.Table"]
+Block = Union[List[T], np.ndarray, "pyarrow.Table", bytes]
 
 
 @DeveloperAPI
@@ -124,6 +124,10 @@ def for_block(block: Block) -> "BlockAccessor[T]":
             from ray.data.impl.arrow_block import \
                 ArrowBlockAccessor
             return ArrowBlockAccessor(block)
+        elif isinstance(block, bytes):
+            from ray.data.impl.arrow_block import \
+                ArrowBlockAccessor
+            return ArrowBlockAccessor.from_bytes(block)
         elif isinstance(block, list):
             from ray.data.impl.simple_block import \
                 SimpleBlockAccessor
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 367b0a744e17..e80e7489c622 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -1257,7 +1257,8 @@ def to_modin(self) -> "modin.DataFrame":
         pd_objs = self.to_pandas()
         return from_partitions(pd_objs, axis=0)
 
-    def to_spark(self) -> "pyspark.sql.DataFrame":
+    def to_spark(self,
+                 spark: "pyspark.sql.SparkSession") -> "pyspark.sql.DataFrame":
         """Convert this dataset into a Spark dataframe.
 
         Time complexity: O(dataset size / parallelism)
@@ -1265,7 +1266,14 @@ def to_spark(self) -> "pyspark.sql.DataFrame":
         Returns:
             A Spark dataframe created from this dataset.
         """
-        raise NotImplementedError  # P2
+        import raydp
+        core_worker = ray.worker.global_worker.core_worker
+        locations = [
+            core_worker.get_owner_address(block)
+            for block in self.get_blocks()
+        ]
+        return raydp.spark.ray_dataset_to_spark_dataframe(
+            spark, self.schema(), self.get_blocks(), locations)
 
     def to_pandas(self) -> List[ObjectRef["pandas.DataFrame"]]:
         """Convert this dataset into a distributed set of Pandas dataframes.
diff --git a/python/ray/data/impl/arrow_block.py b/python/ray/data/impl/arrow_block.py
index a9b8f59fb553..18c2f3753d56 100644
--- a/python/ray/data/impl/arrow_block.py
+++ b/python/ray/data/impl/arrow_block.py
@@ -135,6 +135,11 @@ def __init__(self, table: "pyarrow.Table"):
             raise ImportError("Run `pip install pyarrow` for Arrow support")
         self._table = table
 
+    @classmethod
+    def from_bytes(cls, data: bytes):
+        reader = pyarrow.ipc.open_stream(data)
+        return cls(reader.read_all())
+
     def iter_rows(self) -> Iterator[ArrowRow]:
         outer = self
 
diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py
index 01dd79515f5e..a2daf7732f8e 100644
--- a/python/ray/data/read_api.py
+++ b/python/ray/data/read_api.py
@@ -472,7 +472,6 @@ def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]:
     return Dataset(BlockList(blocks, ray.get(list(metadata))))
 
 
-@PublicAPI(stability="beta")
 def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]:
     """Create a dataset from a set of NumPy ndarrays.
 
@@ -490,34 +489,39 @@
 
 
 @PublicAPI(stability="beta")
-def from_arrow(tables: List[ObjectRef["pyarrow.Table"]]) -> Dataset[ArrowRow]:
+def from_arrow(tables: List[ObjectRef[Union["pyarrow.Table", bytes]]]
+               ) -> Dataset[ArrowRow]:
     """Create a dataset from a set of Arrow tables.
 
     Args:
-        dfs: A list of Ray object references to Arrow tables.
+        tables: A list of Ray object references to Arrow tables,
+            or to bytes serialized in the Arrow streaming format.
 
     Returns:
        Dataset holding Arrow records from the tables.
     """
-
    get_metadata = cached_remote_fn(_get_metadata)
     metadata = [get_metadata.remote(t) for t in tables]
     return Dataset(BlockList(tables, ray.get(metadata)))
 
 
 @PublicAPI(stability="beta")
-def from_spark(df: "pyspark.sql.DataFrame", *,
-               parallelism: int = 200) -> Dataset[ArrowRow]:
+def from_spark(df: "pyspark.sql.DataFrame",
+               *,
+               parallelism: Optional[int] = None) -> Dataset[ArrowRow]:
     """Create a dataset from a Spark dataframe.
 
     Args:
         df: A Spark dataframe, which must be created by RayDP (Spark-on-Ray).
-        parallelism: The amount of parallelism to use for the dataset.
+        parallelism: The amount of parallelism to use for the dataset.
+            If not provided, it will be equal to the number of partitions of
+            the original Spark dataframe.
 
     Returns:
         Dataset holding Arrow records read from the dataframe.
""" - raise NotImplementedError # P2 + import raydp + return raydp.spark.spark_dataframe_to_ray_dataset(df, parallelism) def _df_to_block(df: "pandas.DataFrame") -> Block[ArrowRow]: diff --git a/python/ray/data/tests/test_raydp_dataset.py b/python/ray/data/tests/test_raydp_dataset.py new file mode 100644 index 000000000000..c86c6a0803c1 --- /dev/null +++ b/python/ray/data/tests/test_raydp_dataset.py @@ -0,0 +1,44 @@ +import pytest +import ray +import raydp + + +@pytest.fixture(scope="function") +def spark_on_ray_small(request): + ray.init(num_cpus=2, include_dashboard=False) + spark = raydp.init_spark("test", 1, 1, "500 M") + + def stop_all(): + raydp.stop_spark() + ray.shutdown() + + request.addfinalizer(stop_all) + return spark + + +def test_raydp_roundtrip(spark_on_ray_small): + spark = spark_on_ray_small + spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], + ["one", "two"]) + rows = [(r.one, r.two) for r in spark_df.take(3)] + ds = ray.data.from_spark(spark_df) + values = [(r["one"], r["two"]) for r in ds.take(6)] + assert values == rows + df = ds.to_spark(spark) + rows_2 = [(r.one, r.two) for r in df.take(3)] + assert values == rows_2 + + +def test_raydp_to_spark(spark_on_ray_small): + spark = spark_on_ray_small + n = 5 + ds = ray.data.range_arrow(n) + values = [r["value"] for r in ds.take(5)] + df = ds.to_spark(spark) + rows = [r.value for r in df.take(5)] + assert values == rows + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/requirements/data_processing/requirements.txt b/python/requirements/data_processing/requirements.txt index 7f49e686a9b1..c835ec7472a3 100644 --- a/python/requirements/data_processing/requirements.txt +++ b/python/requirements/data_processing/requirements.txt @@ -7,3 +7,4 @@ s3fs modin>=0.8.3; python_version < '3.7' modin>=0.10.0; python_version >= '3.7' pytest-repeat +raydp-nightly