diff --git a/crates/runtime/src/datafusion/functions/geospatial/accessors/dim.rs b/crates/runtime/src/datafusion/functions/geospatial/accessors/dim.rs index 9b2e85944..7db37850f 100644 --- a/crates/runtime/src/datafusion/functions/geospatial/accessors/dim.rs +++ b/crates/runtime/src/datafusion/functions/geospatial/accessors/dim.rs @@ -26,6 +26,7 @@ use arrow_schema::DataType; use datafusion::logical_expr::scalar_doc_sections::DOC_SECTION_OTHER; use datafusion::logical_expr::{ColumnarValue, Documentation, ScalarUDFImpl, Signature}; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::ScalarFunctionArgs; use geoarrow::array::AsNativeArray; use geoarrow::datatypes::NativeType; use geoarrow::scalar::Geometry; @@ -63,8 +64,8 @@ impl ScalarUDFImpl for GeomDimension { Ok(DataType::UInt8) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - dim_impl(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + dim_impl(&args.args) } fn documentation(&self) -> Option<&Documentation> { @@ -198,9 +199,13 @@ mod tests { ]; for (array, exp) in args { - let args = vec![ColumnarValue::Array(array.clone())]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(array)], + number_rows: 2, + return_type: &DataType::Null, + }; let dim_fn = GeomDimension::new(); - let result = dim_fn.invoke_batch(&args, 2).unwrap().to_array(2).unwrap(); + let result = dim_fn.invoke_with_args(args).unwrap().to_array(2).unwrap(); let result = result.as_primitive::(); assert_eq!(result.value(0), exp); } diff --git a/crates/runtime/src/datafusion/functions/geospatial/accessors/geometry.rs b/crates/runtime/src/datafusion/functions/geospatial/accessors/geometry.rs new file mode 100644 index 000000000..34dd2a176 --- /dev/null +++ b/crates/runtime/src/datafusion/functions/geospatial/accessors/geometry.rs @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::datafusion::functions::geospatial::data_types::{ + any_single_geometry_type_input, parse_to_native_array, +}; +use arrow_array::builder::Float64Builder; +use arrow_schema::DataType; +use arrow_schema::DataType::Float64; +use datafusion::logical_expr::scalar_doc_sections::DOC_SECTION_OTHER; +use datafusion::logical_expr::{ColumnarValue, Documentation, ScalarUDFImpl, Signature}; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::ScalarFunctionArgs; +use geo_traits::CoordTrait; +use geo_traits::RectTrait; +use geoarrow::algorithm::geo::BoundingRect; +use geoarrow::trait_::ArrayAccessor; +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +macro_rules! create_extremum_udf { + ($name:ident, $func_name:expr, $index:expr, $is_max:expr, $doc:expr, $syntax:expr) => { + #[derive(Debug)] + pub struct $name { + signature: Signature, + } + + impl $name { + pub fn new() -> Self { + Self { + signature: any_single_geometry_type_input(), + } + } + } + + impl ScalarUDFImpl for $name { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &'static str { + $func_name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + get_extremum(&args.args, $index, $is_max) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(DOCUMENTATION.get_or_init(|| { + Documentation::builder(DOC_SECTION_OTHER, $doc, $syntax) + .with_argument("g1", "geometry") + .with_related_udf("st_xmin") + .with_related_udf("st_ymin") + .with_related_udf("st_zmin") + .with_related_udf("st_xmax") + .build() + })) + } + } + }; +} + +create_extremum_udf!( + MinX, + "st_xmin", + 0, + false, + "Returns the minimum longitude (X coordinate) of all points contained in the specified geometry.", + "ST_XMin(geom)" +); + +create_extremum_udf!( + MinY, + "st_ymin", + 1, + false, + "Returns the minimum latitude (Y coordinate) of all points contained in the specified geometry.", + "ST_YMin(geom)" +); + +create_extremum_udf!( + MaxX, + "st_xmax", + 0, + true, + "Returns the maximum longitude (X coordinate) of all points contained in the specified geometry.", + "ST_XMax(geom)" +); + +create_extremum_udf!( + MaxY, + "st_ymax", + 1, + true, + "Returns the maximum latitude (Y coordinate) of all points contained in the specified geometry.", + "ST_YMax(geom)" +); + +fn get_extremum(args: &[ColumnarValue], index: i64, is_max: bool) -> Result { + let arg = ColumnarValue::values_to_arrays(args)? + .into_iter() + .next() + .ok_or_else(|| DataFusionError::Execution("Expected only one argument".to_string()))?; + + let array = ColumnarValue::values_to_arrays(args)? + .into_iter() + .next() + .ok_or_else(|| DataFusionError::Execution("Expected at least one argument".to_string()))?; + + let native_array = parse_to_native_array(&array)?; + let native_array_ref = native_array + .as_ref() + .bounding_rect() + .map_err(|e| DataFusionError::Execution(format!("Error getting bounding rect: {e}")))?; + + let mut output_array = Float64Builder::with_capacity(arg.len()); + for rect in native_array_ref.iter() { + match (index, is_max) { + (0, false) => output_array.append_option(rect.map(|r| r.min().x())), + (1, false) => output_array.append_option(rect.map(|r| r.min().y())), + (0, true) => output_array.append_option(rect.map(|r| r.max().x())), + (1, true) => output_array.append_option(rect.map(|r| r.max().y())), + _ => { + return Err(DataFusionError::Execution( + "Index out of bounds".to_string(), + )) + } + } + } + Ok(ColumnarValue::Array(Arc::new(output_array.finish()))) +} + +#[cfg(test)] +mod tests { + use super::*; + use super::{MaxX, MaxY, MinX, MinY}; + use arrow_array::cast::AsArray; + use arrow_array::types::Float64Type; + use arrow_array::ArrayRef; + use datafusion::logical_expr::ColumnarValue; + use geo_types::{line_string, point, polygon}; + use geoarrow::array::{CoordType, LineStringBuilder, PointBuilder, PolygonBuilder}; + use geoarrow::datatypes::Dimension; + use geoarrow::ArrayBase; + + #[test] + #[allow(clippy::unwrap_used, clippy::float_cmp)] + fn test_extrema() { + let dim = Dimension::XY; + let ct = CoordType::Separated; + + let args: [(ArrayRef, [[f64; 2]; 4]); 3] = [ + ( + { + let data = vec![ + line_string![(x: 0., y: 0.), (x: 1., y: 0.), (x: 1., y: 1.), (x: 0., y: 1.), (x: 0., y: 0.)], + line_string![(x: -60., y: -30.), (x: 60., y: -30.)], + ]; + let array = + LineStringBuilder::from_line_strings(&data, dim, ct, Arc::default()) + .finish(); + array.to_array_ref() + }, + [[0., -60.], [1., 60.], [0., -30.], [1., -30.]], + ), + ( + { + let data = [point! {x: 0., y: 0.}, point! {x: 1., y: 1.}]; + let array = + PointBuilder::from_points(data.iter(), dim, ct, Arc::default()).finish(); + array.to_array_ref() + }, + [[0., 1.], [0., 1.], [0., 1.], [0., 1.]], + ), + ( + { + let data = vec![ + polygon![(x: 3.3, y: 30.2), (x: 4.7, y: 24.6), (x: 13.4, y: 25.1), (x: 24.4, y: 30.0),(x:3.3,y:30.4)], + polygon![(x: 3.2, y: 11.1), (x: 4.7, y: 24.6), (x: 13.4, y: 25.1), (x: 19.4, y: 31.0),(x:3.3,y:36.4)], + ]; + let array = + PolygonBuilder::from_polygons(&data, dim, ct, Arc::default()).finish(); + array.to_array_ref() + }, + [[3.3, 3.2], [24.4, 19.4], [24.6, 11.1], [30.4, 36.4]], + ), + ]; + + let udfs: Vec> = vec![ + Box::new(MinX::new()), + Box::new(MaxX::new()), + Box::new(MinY::new()), + Box::new(MaxY::new()), + ]; + + for (array, exp) in args { + for (i, udf) in udfs.iter().enumerate() { + let res = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(array.clone())], + number_rows: 2, + return_type: &DataType::Null, + }) + .unwrap() + .to_array(2) + .unwrap(); + let res = res.as_primitive::(); + assert_eq!(res.value(0), exp[i][0]); + assert_eq!(res.value(1), exp[i][1]); + } + } + } +} diff --git a/crates/runtime/src/datafusion/functions/geospatial/accessors/line_string.rs b/crates/runtime/src/datafusion/functions/geospatial/accessors/line_string.rs index e77df8476..2f4110e6d 100644 --- a/crates/runtime/src/datafusion/functions/geospatial/accessors/line_string.rs +++ b/crates/runtime/src/datafusion/functions/geospatial/accessors/line_string.rs @@ -28,6 +28,7 @@ use datafusion::logical_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::ScalarFunctionArgs; use geo_traits::LineStringTrait; use geoarrow::array::{AsNativeArray, CoordType, PointBuilder}; use geoarrow::datatypes::Dimension; @@ -38,100 +39,76 @@ use snafu::ResultExt; use std::any::Any; use std::sync::{Arc, OnceLock}; -#[derive(Debug)] -pub struct EndPoint { - signature: Signature, -} - -impl EndPoint { - pub fn new() -> Self { - Self { - signature: Signature::exact(vec![LINE_STRING_TYPE.into()], Volatility::Immutable), - } - } -} - static DOCUMENTATION: OnceLock = OnceLock::new(); -impl ScalarUDFImpl for EndPoint { - fn as_any(&self) -> &dyn Any { - self - } +macro_rules! create_line_string_udf { + ($name:ident, $func_name:expr, $index:expr, $doc:expr, $syntax:expr) => { + #[derive(Debug)] + pub struct $name { + signature: Signature, + } - fn name(&self) -> &'static str { - "st_endpoint" - } + impl $name { + pub fn new() -> Self { + Self { + signature: Signature::exact( + vec![LINE_STRING_TYPE.into()], + Volatility::Immutable, + ), + } + } + } - fn signature(&self) -> &Signature { - &self.signature - } + impl ScalarUDFImpl for $name { + fn as_any(&self) -> &dyn Any { + self + } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(POINT2D_TYPE.into()) - } + fn name(&self) -> &'static str { + $func_name + } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - get_n_point(args, None) - } + fn signature(&self) -> &Signature { + &self.signature + } - fn documentation(&self) -> Option<&Documentation> { - Some(DOCUMENTATION.get_or_init(|| { - Documentation::builder( - DOC_SECTION_OTHER, - "Returns the last point of a LINESTRING geometry as a POINT. Returns NULL if the input is not a LINESTRING", - "ST_EndPoint(line_string)") - .with_argument("g1", "geometry") - .build() - })) - } -} + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(POINT2D_TYPE.into()) + } -#[derive(Debug)] -pub struct StartPoint { - signature: Signature, -} + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + get_n_point(&args.args, $index) + } -impl StartPoint { - pub fn new() -> Self { - Self { - signature: Signature::exact(vec![LINE_STRING_TYPE.into()], Volatility::Immutable), + fn documentation(&self) -> Option<&Documentation> { + Some(DOCUMENTATION.get_or_init(|| { + Documentation::builder(DOC_SECTION_OTHER, $doc, $syntax) + .with_argument("g1", "geometry") + .with_related_udf("st_startpoint") + .with_related_udf("st_pointn") + .with_related_udf("st_endpoint") + .build() + })) + } } - } + }; } -impl ScalarUDFImpl for StartPoint { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &'static str { - "st_startpoint" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(POINT2D_TYPE.into()) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - get_n_point(args, Some(1)) - } - - fn documentation(&self) -> Option<&Documentation> { - Some(DOCUMENTATION.get_or_init(|| { - Documentation::builder( - DOC_SECTION_OTHER, - "Returns the first point of a LINESTRING geometry as a POINT.", - "ST_StartPoint(line_string)", - ) - .with_argument("g1", "geometry") - .build() - })) - } -} +create_line_string_udf!( + EndPoint, + "st_endpoint", + None, + "Returns the last point of a LINESTRING geometry as a POINT. Returns NULL if the input is not a LINESTRING", + "ST_EndPoint(line_string)" +); + +create_line_string_udf!( + StartPoint, + "st_startpoint", + Some(1), + "Returns the first point of a LINESTRING geometry as a POINT.", + "ST_StartPoint(geom)" +); #[derive(Debug)] pub struct PointN { @@ -175,14 +152,15 @@ impl ScalarUDFImpl for PointN { Ok(POINT2D_TYPE.into()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = args.args; if args.len() < 2 { return Err(DataFusionError::Execution( "Expected two arguments in ST_PointN".to_string(), )); } let index = to_primitive_array::(&args[1])?.value(0); - get_n_point(args, Some(index)) + get_n_point(&args, Some(index)) } fn documentation(&self) -> Option<&Documentation> { @@ -245,6 +223,7 @@ fn get_n_point(args: &[ColumnarValue], n: Option) -> Result #[cfg(test)] mod tests { use super::*; + use super::{EndPoint, StartPoint}; use arrow_array::Array; use datafusion::logical_expr::ColumnarValue; use geo_types::line_string; @@ -255,7 +234,7 @@ mod tests { #[test] #[allow(clippy::unwrap_used)] - fn test_start_point() { + fn test_start_end_point() { let data = vec![ line_string![(x: 1., y: 1.), (x: 1., y: 0.), (x: 1., y: 1.)], line_string![(x: 2., y: 2.), (x: 3., y: 2.), (x: 3., y: 3.)], @@ -267,46 +246,31 @@ mod tests { CoordType::Separated, Arc::default(), ) - .finish(); - - let data = array.to_array_ref(); - let args = vec![ColumnarValue::Array(data)]; - let start_point = StartPoint::new(); - let result = start_point.invoke_batch(&args, 3).unwrap(); - let result = result.to_array(3).unwrap(); - assert_eq!(result.data_type(), &POINT2D_TYPE.into()); - let result = PointArray::try_from((result.as_ref(), Dimension::XY)).unwrap(); - assert_eq!(result.get(0).unwrap().to_wkt().unwrap(), "POINT(1 1)"); - assert_eq!(result.get(1).unwrap().to_wkt().unwrap(), "POINT(2 2)"); - assert_eq!(result.get(2).unwrap().to_wkt().unwrap(), "POINT(2 2)"); - } - - #[test] - #[allow(clippy::unwrap_used)] - fn test_end_point() { - let data = vec![ - line_string![(x: 0., y: 0.), (x: 1., y: 0.), (x: 1., y: 1.)], - line_string![(x: 2., y: 2.), (x: 3., y: 2.), (x: 3., y: 3.)], - line_string![(x: 2., y: 2.), (x: 3., y: 2.)], + .finish() + .to_array_ref(); + + let udfs: Vec> = + vec![Box::new(StartPoint::new()), Box::new(EndPoint::new())]; + let results: [[&str; 3]; 2] = [ + ["POINT(1 1)", "POINT(2 2)", "POINT(2 2)"], + ["POINT(1 1)", "POINT(3 3)", "POINT(3 2)"], ]; - let array = LineStringBuilder::from_line_strings( - &data, - Dimension::XY, - CoordType::Separated, - Arc::default(), - ) - .finish(); - let data = array.to_array_ref(); - let args = vec![ColumnarValue::Array(data)]; - let end_point = EndPoint::new(); - let result = end_point.invoke_batch(&args, 3).unwrap(); - let result = result.to_array(3).unwrap(); - assert_eq!(result.data_type(), &POINT2D_TYPE.into()); - let result = PointArray::try_from((result.as_ref(), Dimension::XY)).unwrap(); - assert_eq!(result.get(0).unwrap().to_wkt().unwrap(), "POINT(1 1)"); - assert_eq!(result.get(1).unwrap().to_wkt().unwrap(), "POINT(3 3)"); - assert_eq!(result.get(2).unwrap().to_wkt().unwrap(), "POINT(3 2)"); + for (idx, udf) in udfs.iter().enumerate() { + let result = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(array.clone())], + number_rows: 3, + return_type: &DataType::Null, + }) + .unwrap(); + let result = result.to_array(3).unwrap(); + assert_eq!(result.data_type(), &POINT2D_TYPE.into()); + let result = PointArray::try_from((result.as_ref(), Dimension::XY)).unwrap(); + assert_eq!(result.get(0).unwrap().to_wkt().unwrap(), results[idx][0]); + assert_eq!(result.get(1).unwrap().to_wkt().unwrap(), results[idx][1]); + assert_eq!(result.get(2).unwrap().to_wkt().unwrap(), results[idx][2]); + } } #[test] @@ -335,12 +299,17 @@ mod tests { for (index, ok, exp) in cases { let data = array.to_array_ref(); - let args = vec![ - ColumnarValue::Array(data), - ColumnarValue::Scalar(index.into()), - ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(data), + ColumnarValue::Scalar(index.into()), + ], + number_rows: 3, + return_type: &DataType::Null, + }; + let point_n = PointN::new(); - let result = point_n.invoke_batch(&args, 3); + let result = point_n.invoke_with_args(args); if ok { let result = result.unwrap().to_array(3).unwrap(); diff --git a/crates/runtime/src/datafusion/functions/geospatial/accessors/mod.rs b/crates/runtime/src/datafusion/functions/geospatial/accessors/mod.rs index 0889773ce..4ce379788 100644 --- a/crates/runtime/src/datafusion/functions/geospatial/accessors/mod.rs +++ b/crates/runtime/src/datafusion/functions/geospatial/accessors/mod.rs @@ -16,6 +16,7 @@ // under the License. mod dim; +mod geometry; mod line_string; mod point; mod srid; @@ -31,4 +32,8 @@ pub fn register_udfs(ctx: &SessionContext) { ctx.register_udf(srid::Srid::new().into()); ctx.register_udf(point::PointX::new().into()); ctx.register_udf(point::PointY::new().into()); + ctx.register_udf(geometry::MinX::new().into()); + ctx.register_udf(geometry::MinY::new().into()); + ctx.register_udf(geometry::MaxX::new().into()); + ctx.register_udf(geometry::MaxY::new().into()); } diff --git a/crates/runtime/src/datafusion/functions/geospatial/accessors/point.rs b/crates/runtime/src/datafusion/functions/geospatial/accessors/point.rs index 0881e9dff..b38c1341a 100644 --- a/crates/runtime/src/datafusion/functions/geospatial/accessors/point.rs +++ b/crates/runtime/src/datafusion/functions/geospatial/accessors/point.rs @@ -25,6 +25,7 @@ use datafusion::logical_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::ScalarFunctionArgs; use geo_traits::{CoordTrait, PointTrait}; use geoarrow::array::AsNativeArray; use geoarrow::error::GeoArrowError; @@ -34,101 +35,71 @@ use snafu::ResultExt; use std::any::Any; use std::sync::{Arc, OnceLock}; -#[derive(Debug)] -pub struct PointX { - signature: Signature, -} - -impl PointX { - pub fn new() -> Self { - Self { - signature: Signature::exact(vec![POINT2D_TYPE.into()], Volatility::Immutable), - } - } -} - static DOCUMENTATION: OnceLock = OnceLock::new(); +macro_rules! create_point_udf { + ($name:ident, $func_name:expr, $index:expr, $doc:expr, $syntax:expr) => { + #[derive(Debug)] + pub struct $name { + signature: Signature, + } -impl ScalarUDFImpl for PointX { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &'static str { - "st_x" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Float64) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - get_coord(args, 0) - } - - fn documentation(&self) -> Option<&Documentation> { - Some(DOCUMENTATION.get_or_init(|| { - Documentation::builder( - DOC_SECTION_OTHER, - "Returns the longitude (X coordinate) of a Point represented by geometry.", - "ST_X(geom)", - ) - .with_argument("g1", "geometry") - .build() - })) - } -} - -#[derive(Debug)] -pub struct PointY { - signature: Signature, -} + impl $name { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![POINT2D_TYPE.into()], Volatility::Immutable), + } + } + } -impl PointY { - pub fn new() -> Self { - Self { - signature: Signature::exact(vec![POINT2D_TYPE.into()], Volatility::Immutable), + impl ScalarUDFImpl for $name { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &'static str { + $func_name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + get_coord(&args.args, $index) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(DOCUMENTATION.get_or_init(|| { + Documentation::builder(DOC_SECTION_OTHER, $doc, $syntax) + .with_argument("g1", "geometry") + .with_related_udf("st_x") + .with_related_udf("st_y") + .build() + })) + } } - } + }; } -impl ScalarUDFImpl for PointY { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &'static str { - "st_y" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Float64) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - get_coord(args, 1) - } - - fn documentation(&self) -> Option<&Documentation> { - Some(DOCUMENTATION.get_or_init(|| { - Documentation::builder( - DOC_SECTION_OTHER, - "Returns the latitude (Y coordinate) of a Point represented by geometry.", - "ST_Y(geom)", - ) - .with_argument("g1", "geometry") - .build() - })) - } -} +create_point_udf!( + PointX, + "st_x", + 0, + "Returns the longitude (X coordinate) of a Point represented by geometry.", + "ST_X(geom)" +); + +create_point_udf!( + PointY, + "st_y", + 1, + "Returns the latitude (Y coordinate) of a Point represented by geometry.", + "ST_Y(geom)" +); fn get_coord(args: &[ColumnarValue], n: i64) -> Result { let array = ColumnarValue::values_to_arrays(args)? @@ -181,7 +152,7 @@ mod tests { #[test] #[allow(clippy::unwrap_used, clippy::float_cmp)] - fn test_x() { + fn test_points() { let pa = PointBuilder::from_points( [ point! {x: 4., y: 2.}, @@ -196,41 +167,23 @@ mod tests { .finish() .to_array_ref(); - let args = vec![ColumnarValue::Array(pa)]; - let x = PointX::new(); - let result = x.invoke_batch(&args, 3).unwrap(); - let result = result.to_array(3).unwrap(); - - let result = result.as_primitive::(); - assert_eq!(result.value(0), 4.0); - assert_eq!(result.value(1), 1.0); - assert_eq!(result.value(2), 2.0); - } - #[test] - #[allow(clippy::unwrap_used, clippy::float_cmp)] - fn test_y() { - let pa = PointBuilder::from_points( - [ - point! {x: 4., y: 0.}, - point! {x: 1., y: 2.}, - point! {x: 2., y: 3.}, - ] - .iter(), - Dimension::XY, - CoordType::Separated, - Arc::default(), - ) - .finish() - .to_array_ref(); - - let args = vec![ColumnarValue::Array(pa)]; - let y = PointY::new(); - let result = y.invoke_batch(&args, 3).unwrap(); - let result = result.to_array(3).unwrap(); - - let result = result.as_primitive::(); - assert_eq!(result.value(0), 0.0); - assert_eq!(result.value(1), 2.0); - assert_eq!(result.value(2), 3.0); + let results: [[f64; 3]; 2] = [[4., 1., 2.], [2., 2., 3.]]; + let udfs: Vec> = + vec![Box::new(PointX::new()), Box::new(PointY::new())]; + + for (idx, udf) in udfs.iter().enumerate() { + let result = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(pa.clone())], + number_rows: 3, + return_type: &DataType::Null, + }) + .unwrap(); + let result = result.to_array(3).unwrap(); + let result = result.as_primitive::(); + assert_eq!(result.value(0), results[idx][0]); + assert_eq!(result.value(1), results[idx][1]); + assert_eq!(result.value(2), results[idx][2]); + } } } diff --git a/crates/runtime/src/datafusion/functions/geospatial/accessors/srid.rs b/crates/runtime/src/datafusion/functions/geospatial/accessors/srid.rs index 429925ad3..078400c49 100644 --- a/crates/runtime/src/datafusion/functions/geospatial/accessors/srid.rs +++ b/crates/runtime/src/datafusion/functions/geospatial/accessors/srid.rs @@ -26,6 +26,7 @@ use arrow_schema::DataType; use datafusion::logical_expr::scalar_doc_sections::DOC_SECTION_OTHER; use datafusion::logical_expr::{ColumnarValue, Documentation, ScalarUDFImpl, Signature}; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::ScalarFunctionArgs; use geoarrow::array::AsNativeArray; use geoarrow::datatypes::NativeType; use geoarrow::trait_::ArrayAccessor; @@ -65,8 +66,8 @@ impl ScalarUDFImpl for Srid { Ok(DataType::Int32) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - dim_impl(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + dim_impl(&args.args) } fn documentation(&self) -> Option<&Documentation> { @@ -84,7 +85,6 @@ impl ScalarUDFImpl for Srid { macro_rules! build_output_array { ($arr:expr) => {{ - print!("arr: {:?}", $arr); let mut output_array = Int32Builder::with_capacity($arr.len()); for geom in $arr.iter() { if let Some(p) = geom { @@ -136,6 +136,7 @@ mod tests { use arrow_array::types::Int32Type; use arrow_array::ArrayRef; use datafusion::logical_expr::ColumnarValue; + use datafusion_expr::ScalarFunctionArgs; use geo_types::{line_string, point, polygon}; use geoarrow::array::LineStringBuilder; use geoarrow::array::{CoordType, PointBuilder, PolygonBuilder}; @@ -186,9 +187,13 @@ mod tests { ]; for (array, exp) in args { - let args = vec![ColumnarValue::Array(array.clone())]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(array)], + number_rows: 2, + return_type: &DataType::Null, + }; let srid_fn = Srid::new(); - let result = srid_fn.invoke_batch(&args, 2).unwrap().to_array(2).unwrap(); + let result = srid_fn.invoke_with_args(args).unwrap().to_array(2).unwrap(); let result = result.as_primitive::(); assert_eq!(result.value(0), exp); } diff --git a/crates/runtime/src/datafusion/functions/geospatial/data_types.rs b/crates/runtime/src/datafusion/functions/geospatial/data_types.rs index bd9caf6f6..263c941e1 100644 --- a/crates/runtime/src/datafusion/functions/geospatial/data_types.rs +++ b/crates/runtime/src/datafusion/functions/geospatial/data_types.rs @@ -24,7 +24,8 @@ use arrow_array::ArrayRef; use datafusion::error::DataFusionError; use datafusion::logical_expr::{Signature, Volatility}; use geoarrow::array::{ - CoordType, GeometryArray, LineStringArray, PointArray, PolygonArray, RectArray, + CoordType, GeometryArray, GeometryCollectionArray, LineStringArray, PointArray, PolygonArray, + RectArray, }; use geoarrow::datatypes::{Dimension, NativeType}; use geoarrow::NativeArray; @@ -35,6 +36,9 @@ pub const POINT3D_TYPE: NativeType = NativeType::Point(CoordType::Separated, Dim pub const BOX2D_TYPE: NativeType = NativeType::Rect(Dimension::XY); pub const BOX3D_TYPE: NativeType = NativeType::Rect(Dimension::XYZ); pub const GEOMETRY_TYPE: NativeType = NativeType::Geometry(CoordType::Separated); +pub const GEOMETRY_COLLECTION_TYPE: NativeType = + NativeType::GeometryCollection(CoordType::Separated, Dimension::XY); + pub const LINE_STRING_TYPE: NativeType = NativeType::LineString(CoordType::Separated, Dimension::XY); pub const POLYGON_2D_TYPE: NativeType = NativeType::Polygon(CoordType::Separated, Dimension::XY); @@ -51,6 +55,7 @@ pub fn any_single_geometry_type_input() -> Signature { LINE_STRING_TYPE.into(), POLYGON_2D_TYPE.into(), GEOMETRY_TYPE.into(), + GEOMETRY_COLLECTION_TYPE.into(), ], Volatility::Immutable, ) @@ -87,6 +92,11 @@ pub fn parse_to_native_array(array: &ArrayRef) -> GeoDataFusionResult