Skip to content

Commit 21265cb

Browse files
authored
Add ST_Contains (#327)
* Add ST_MakePolygon, ST_Polygon * Add docs * Add ST_Dimension * Add license * Add ST_EndPoint, ST_PointN * Reimplement ST_Point to work with Snowflake linestring * Add ST_X, ST_Y, ST_SRID * Add ST_Min/ST_Max for x, y * Add macro for end and start points * Add ST_Contains
1 parent d59876f commit 21265cb

File tree

4 files changed

+276
-14
lines changed

4 files changed

+276
-14
lines changed

crates/runtime/src/datafusion/functions/geospatial/data_types.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::datafusion::functions::geospatial::error::{
2121
self as geo_error, GeoDataFusionError, GeoDataFusionResult,
2222
};
2323
use arrow_array::ArrayRef;
24+
use arrow_schema::DataType;
2425
use datafusion::error::DataFusionError;
2526
use datafusion::logical_expr::{Signature, Volatility};
2627
use geoarrow::array::{
@@ -45,20 +46,21 @@ pub const POLYGON_2D_TYPE: NativeType = NativeType::Polygon(CoordType::Separated
4546

4647
#[must_use]
4748
pub fn any_single_geometry_type_input() -> Signature {
48-
Signature::uniform(
49-
1,
50-
vec![
51-
POINT2D_TYPE.into(),
52-
POINT3D_TYPE.into(),
53-
BOX2D_TYPE.into(),
54-
BOX3D_TYPE.into(),
55-
LINE_STRING_TYPE.into(),
56-
POLYGON_2D_TYPE.into(),
57-
GEOMETRY_TYPE.into(),
58-
GEOMETRY_COLLECTION_TYPE.into(),
59-
],
60-
Volatility::Immutable,
61-
)
49+
Signature::uniform(1, geo_types(), Volatility::Immutable)
50+
}
51+
52+
#[must_use]
53+
pub fn geo_types() -> Vec<DataType> {
54+
vec![
55+
POINT2D_TYPE.into(),
56+
POINT3D_TYPE.into(),
57+
BOX2D_TYPE.into(),
58+
BOX3D_TYPE.into(),
59+
LINE_STRING_TYPE.into(),
60+
POLYGON_2D_TYPE.into(),
61+
GEOMETRY_TYPE.into(),
62+
GEOMETRY_COLLECTION_TYPE.into(),
63+
]
6264
}
6365

6466
/// This will not cast a `PointArray` to a `GeometryArray`
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::datafusion::functions::geospatial::data_types::parse_to_native_array;
19+
use arrow_schema::DataType;
20+
use datafusion::logical_expr::scalar_doc_sections::DOC_SECTION_OTHER;
21+
use datafusion::logical_expr::{
22+
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
23+
};
24+
use datafusion_common::{DataFusionError, Result};
25+
use datafusion_expr::ScalarFunctionArgs;
26+
use geoarrow::algorithm::geo::Contains as ContainsTrait;
27+
use geoarrow::array::AsNativeArray;
28+
use geoarrow::datatypes::NativeType;
29+
use std::any::Any;
30+
use std::sync::{Arc, OnceLock};
31+
32+
#[derive(Debug)]
33+
pub struct Contains {
34+
signature: Signature,
35+
}
36+
37+
impl Contains {
38+
#[must_use]
39+
pub fn new() -> Self {
40+
Self {
41+
signature: Signature::any(2, Volatility::Immutable),
42+
}
43+
}
44+
}
45+
46+
static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
47+
48+
impl ScalarUDFImpl for Contains {
49+
fn as_any(&self) -> &dyn Any {
50+
self
51+
}
52+
53+
fn name(&self) -> &'static str {
54+
"st_contains"
55+
}
56+
57+
fn signature(&self) -> &Signature {
58+
&self.signature
59+
}
60+
61+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
62+
Ok(DataType::Boolean)
63+
}
64+
65+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
66+
contains(&args.args)
67+
}
68+
69+
fn documentation(&self) -> Option<&Documentation> {
70+
Some(DOCUMENTATION.get_or_init(|| {
71+
Documentation::builder(
72+
DOC_SECTION_OTHER,
73+
"Returns TRUE the geometry object is completely inside another object of the same type.",
74+
"ST_Contains(g1, g2)",
75+
)
76+
.with_argument("g1", "geometry")
77+
.with_argument("g2", "geometry")
78+
.with_related_udf("st_within")
79+
.with_related_udf("st_covers")
80+
.build()
81+
}))
82+
}
83+
}
84+
85+
/// Runs the geoarrow `Contains` kernel for an already-identified left-hand
/// array against a right-hand array whose concrete type is resolved here.
///
/// * `$lhs` - the left native array.
/// * `$lhs_cast` - the `as_*` accessor that downcasts `$lhs` to its concrete type.
/// * `$other` - the right native array; its `data_type()` selects the accessor.
///
/// Expands to an expression yielding the kernel result. For an unsupported
/// right-hand type it `return`s an execution error from the enclosing function.
macro_rules! match_rhs_data_type {
    ($lhs:expr, $lhs_cast:ident, $other:expr) => {{
        let typed_lhs = $lhs.$lhs_cast();
        match $other.data_type() {
            NativeType::Point(_, _) => ContainsTrait::contains(typed_lhs, $other.as_point()),
            NativeType::LineString(_, _) => {
                ContainsTrait::contains(typed_lhs, $other.as_line_string())
            }
            NativeType::Polygon(_, _) => ContainsTrait::contains(typed_lhs, $other.as_polygon()),
            NativeType::MultiPoint(_, _) => {
                ContainsTrait::contains(typed_lhs, $other.as_multi_point())
            }
            NativeType::MultiLineString(_, _) => {
                ContainsTrait::contains(typed_lhs, $other.as_multi_line_string())
            }
            NativeType::MultiPolygon(_, _) => {
                ContainsTrait::contains(typed_lhs, $other.as_multi_polygon())
            }
            _ => {
                return Err(DataFusionError::Execution(
                    "ST_Contains does not support this rhs geometry type".to_string(),
                ))
            }
        }
    }};
}
114+
115+
fn contains(args: &[ColumnarValue]) -> Result<ColumnarValue> {
116+
let array = ColumnarValue::values_to_arrays(args)?;
117+
if array.len() > 2 {
118+
return Err(DataFusionError::Execution(
119+
"ST_Contains takes two arguments".to_string(),
120+
));
121+
}
122+
123+
let left = parse_to_native_array(&array[0])?;
124+
let left = left.as_ref();
125+
let rhs = parse_to_native_array(&array[1])?;
126+
let rhs = rhs.as_ref();
127+
128+
let result = match left.data_type() {
129+
NativeType::Point(_, _) => match_rhs_data_type!(left, as_point, rhs),
130+
NativeType::LineString(_, _) => match_rhs_data_type!(left, as_line_string, rhs),
131+
NativeType::Polygon(_, _) => match_rhs_data_type!(left, as_polygon, rhs),
132+
NativeType::MultiPoint(_, _) => match_rhs_data_type!(left, as_multi_point, rhs),
133+
NativeType::MultiLineString(_, _) => match_rhs_data_type!(left, as_multi_line_string, rhs),
134+
NativeType::MultiPolygon(_, _) => match_rhs_data_type!(left, as_multi_polygon, rhs),
135+
_ => {
136+
return Err(DataFusionError::Execution(
137+
"ST_Contains does not support this left geometry type".to_string(),
138+
))
139+
}
140+
};
141+
Ok(ColumnarValue::Array(Arc::new(result)))
142+
}
143+
144+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::cast::AsArray;
    use arrow_array::ArrayRef;
    use datafusion::logical_expr::ColumnarValue;
    use geo_types::{line_string, point, polygon};
    use geoarrow::array::LineStringBuilder;
    use geoarrow::array::{CoordType, PointBuilder, PolygonBuilder};
    use geoarrow::datatypes::Dimension;
    use geoarrow::ArrayBase;

    /// Exercises `st_contains` on line-string/point, point/point and
    /// polygon/point inputs, two rows per case.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_contains() {
        let dim = Dimension::XY;
        let ct = CoordType::Separated;

        // Case 1: closed line strings on the left, points on the right.
        let rings = vec![
            line_string![(x: 0., y: 0.), (x: 1., y: 2.), (x: 1., y: 1.), (x: 0., y: 1.), (x: 0., y: 0.)],
            line_string![(x: 2., y: 2.), (x: 3., y: 2.), (x: 3., y: 3.), (x: 2., y: 3.), (x: 2., y: 2.)],
        ];
        let ring_lhs: ArrayRef =
            LineStringBuilder::from_line_strings(&rings, dim, ct, Arc::default())
                .finish()
                .to_array_ref();
        let ring_probe_points = [point! {x: 0., y: 0.}, point! {x: 1., y: 1.}];
        let ring_rhs: ArrayRef =
            PointBuilder::from_points(ring_probe_points.iter(), dim, ct, Arc::default())
                .finish()
                .to_array_ref();

        // Case 2: points on both sides; only the first row holds equal points.
        let lhs_points = [point! {x: 0., y: 0.}, point! {x: 1., y: 1.}];
        let point_lhs: ArrayRef =
            PointBuilder::from_points(lhs_points.iter(), dim, ct, Arc::default())
                .finish()
                .to_array_ref();
        let rhs_points = [point! {x: 0., y: 0.}, point! {x: 0., y: 0.}];
        let point_rhs: ArrayRef =
            PointBuilder::from_points(rhs_points.iter(), dim, ct, Arc::default())
                .finish()
                .to_array_ref();

        // Case 3: polygons on the left, points on the right.
        let polygons = vec![
            polygon![(x: 3.3, y: 30.5), (x: 1.7, y: 24.6), (x: 13.4, y: 25.1), (x: 14.4, y: 31.0),(x:3.3,y:30.4)],
            polygon![(x: 3.3, y: 30.4), (x: 1.7, y: 24.6), (x: 13.4, y: 25.1), (x: 14.4, y: 31.0),(x:3.3,y:30.4)],
        ];
        let polygon_lhs: ArrayRef =
            PolygonBuilder::from_polygons(&polygons, dim, ct, Arc::default())
                .finish()
                .to_array_ref();
        let polygon_probe_points = [point! {x: 7.9, y: 28.4}, point! {x: 0., y: 0.}];
        let polygon_rhs: ArrayRef =
            PointBuilder::from_points(polygon_probe_points.iter(), dim, ct, Arc::default())
                .finish()
                .to_array_ref();

        let cases: Vec<(ArrayRef, ArrayRef, [bool; 2])> = vec![
            (ring_lhs, ring_rhs, [true, false]),
            (point_lhs, point_rhs, [true, false]),
            (polygon_lhs, polygon_rhs, [true, false]),
        ];

        for (lhs, rhs, expected) in cases {
            let udf_args = ScalarFunctionArgs {
                args: vec![ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)],
                number_rows: 2,
                return_type: &DataType::Null,
            };
            let udf = Contains::new();
            let output = udf.invoke_with_args(udf_args).unwrap().to_array(2).unwrap();
            let flags = output.as_boolean();
            for (row, want) in expected.iter().enumerate() {
                assert_eq!(flags.value(row), *want);
            }
        }
    }
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
mod contains;
19+
20+
use datafusion::prelude::SessionContext;
21+
22+
pub fn register_udfs(ctx: &SessionContext) {
23+
ctx.register_udf(contains::Contains::new().into());
24+
}

crates/runtime/src/datafusion/functions/geospatial/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ pub mod accessors;
1919
pub mod constructors;
2020
pub mod data_types;
2121
pub mod error;
22+
mod measurement;
2223

2324
use datafusion::prelude::SessionContext;
2425

2526
pub fn register_udfs(ctx: &SessionContext) {
2627
constructors::register_udfs(ctx);
2728
accessors::register_udfs(ctx);
29+
measurement::register_udfs(ctx);
2830
}

0 commit comments

Comments
 (0)