From 36b210918c4fd0a5c2b92b7e9123ca49aa20563a Mon Sep 17 00:00:00 2001 From: Eric Schneider Date: Wed, 27 Aug 2025 10:12:59 -0500 Subject: [PATCH 1/7] Prototype for filter visualization --- .../detection/region_detections.py | 4 +- .../postprocessing/postprocessing.py | 2 +- tree_detection_framework/utils/geometric.py | 2 +- .../utils/visualization.py | 75 +++++++++++++++++++ 4 files changed, 79 insertions(+), 4 deletions(-) create mode 100644 tree_detection_framework/utils/visualization.py diff --git a/tree_detection_framework/detection/region_detections.py b/tree_detection_framework/detection/region_detections.py index 4782788..8029cc8 100644 --- a/tree_detection_framework/detection/region_detections.py +++ b/tree_detection_framework/detection/region_detections.py @@ -465,7 +465,7 @@ def merge( self, region_ID_key: Optional[str] = "region_ID", CRS: Optional[pyproj.CRS] = None, - ): + ) -> RegionDetections: """Get the merged detections across all regions with an additional field specifying which region the detection came from. @@ -478,7 +478,7 @@ def merge( be used. Defaults to None. Returns: - gpd.GeoDataFrame: Detections in the requested CRS + RegionDetections: Detections in the requested CRS """ # Get the detections from each region detection object as geodataframes diff --git a/tree_detection_framework/postprocessing/postprocessing.py b/tree_detection_framework/postprocessing/postprocessing.py index 9d10951..1a10938 100644 --- a/tree_detection_framework/postprocessing/postprocessing.py +++ b/tree_detection_framework/postprocessing/postprocessing.py @@ -630,7 +630,7 @@ def remove_masked_detections( region_detection_sets: Union[List[RegionDetectionsSet], List[PATH_TYPE]], mask_iterator: Iterable[np.ndarray], threshold: float = 0.4, -) -> List[RegionDetectionsSet]: +) -> Union[List[RegionDetectionsSet], List[gpd.GeoDataFrame]]: """ Filters out detections that marked as invalid in the given masks. diff --git a/tree_detection_framework/utils/geometric.py b/tree_detection_framework/utils/geometric.py index 6912f2e..73ab452 100644 --- a/tree_detection_framework/utils/geometric.py +++ b/tree_detection_framework/utils/geometric.py @@ -115,7 +115,7 @@ def ellipse_mask( center (Tuple[int, int]): (x0, y0) center of the ellipse in x and y (pixels) axes (Tuple[int, int]): - (a, b) semi-major and semi-minor axis lengths in pixels + (a, b) x and y axis lengths (before rotation) in pixels angle_rad (float): Rotation angle of the semi-major axis, in radians. CCW from x-axis (a.k.a. right-hand rule out of the image). Defaults to 0. diff --git a/tree_detection_framework/utils/visualization.py b/tree_detection_framework/utils/visualization.py new file mode 100644 index 0000000..c1d9700 --- /dev/null +++ b/tree_detection_framework/utils/visualization.py @@ -0,0 +1,75 @@ +from typing import List, Optional, Union + +import geopandas as gpd +import numpy as np +from matplotlib import colormaps +from PIL import Image, ImageDraw + +from tree_detection_framework.constants import PATH_TYPE +from tree_detection_framework.detection.region_detections import RegionDetectionsSet + + +def show_filtered_detections( + impath: PATH_TYPE, + detection1: Union[RegionDetectionsSet, PATH_TYPE], + detection2: Union[RegionDetectionsSet, PATH_TYPE], + mask: np.ndarray, + mask_colormap: Optional[dict] = None, +) -> List[np.ndarray]: + """ + Arguments: + + Returns: + """ + + # TODO + image = Image.open(impath).convert("RGBA") + + # TODO + def to_gdf(detection): + """Helper to unify the two input types""" + if isinstance(detection, RegionDetectionsSet): + return detection.merge().detections + else: + return gpd.read_file(detection) + gdf1 = to_gdf(detection1) + gdf2 = to_gdf(detection2) + + # TODO + detection_overlay = Image.new("RGBA", image.size, (0, 0, 0, 0)) + detection_draw = ImageDraw.Draw(detection_overlay, "RGBA") + for idx, row in gdf1.iterrows(): + has_match = row["unique_ID"] in gdf2["unique_ID"] + color = (0, 255, 0, 70) if has_match else (255, 0, 0, 70) + if row.geometry.geom_type == "Polygon": + polygons = [row.geometry] + elif row.geometry.geom_type == "MultiPolygon": + polygons = list(row.geometry.geoms) + else: + raise NotImplementedError(f"Can't handle {type(row.geometry)}") + for poly in polygons: + # Convert polygon to pixel coordinates + coords = list(poly.exterior.coords) + # coords are (x, y), but PIL expects (col, row) so we're good + detection_draw.polygon(coords, outline=(0, 0, 0, 255), fill=color) + + # TODO + mask_overlay = np.zeros((image.height, image.width, 4), dtype=np.uint8) + for value in np.unique(mask): + if mask_colormap is None: + color = (np.array(colormaps["tab20"](value)) * 255).astype(np.uint8) + # Set the alpha channel to partially transparent + color[3] = 100 + else: + color = mask_colormap[value] + mask_overlay[mask == value] = color + mask_overlay = Image.fromarray(mask_overlay, mode="RGBA") + + # TODO + def to_ndarray(overlay): + combo = Image.alpha_composite(image, overlay) + return np.asarray(combo.convert("RGB")) + return ( + to_ndarray(detection_overlay), + to_ndarray(mask_overlay), + ) \ No newline at end of file From 00dbdbe0021930ba7d0f15601c243d4b78a74ab1 Mon Sep 17 00:00:00 2001 From: Eric Schneider Date: Wed, 27 Aug 2025 16:34:43 +0000 Subject: [PATCH 2/7] Fix empty GDF case --- tests/test_postprocessing.py | 21 +++++++++++++++++++ .../postprocessing/postprocessing.py | 21 +++++++++++-------- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py index 3adeff6..cc75cec 100644 --- a/tests/test_postprocessing.py +++ b/tests/test_postprocessing.py @@ -259,3 +259,24 @@ def test_basic(self, tmp_path, flat, ellipse_radius, use_rds, geo_extension): expected_map = {10: [2], 30: [2, 3, 4, 5, 6], 50: [2, 3, 4, 5, 6, 7, 8, 9, 10]} for rds in filtered_detection_list: check_expected_rds(rds, use_rds, expected_map[ellipse_radius]) + + @pytest.mark.parametrize( + "use_rds,geo_extension", + ([True, None], [False, ".gpkg"], [False, ".geojson"], [False, ".shp"]), + ) + def test_empty(self, tmp_path, use_rds, geo_extension): + + # Make an empty set of detections + detection_list = get_detections([], use_rds, tmp_path, geo_extension) + + # Create a mask + mask = np.ones(shape=(150, 150), dtype=bool) + + # Filter the empty detections + filtered_detection_list = remove_masked_detections( + region_detection_sets=detection_list, + mask_iterator=repeat(mask), + threshold=0.5, + ) + for rds in filtered_detection_list: + check_expected_rds(rds, use_rds, []) diff --git a/tree_detection_framework/postprocessing/postprocessing.py b/tree_detection_framework/postprocessing/postprocessing.py index 1a10938..41add00 100644 --- a/tree_detection_framework/postprocessing/postprocessing.py +++ b/tree_detection_framework/postprocessing/postprocessing.py @@ -699,15 +699,18 @@ def iterator(rds): # over the chips that detections were calculated in for idx, gdf in iterator(rds): - # Get the mean value of each detection polygon, as a list of - # [{"mean": }, ...] - stats = rasterstats.zonal_stats(gdf, mask, stats=["mean"], affine=transform) - - # Get indices that have a greater fraction of good pixels than the - # threshold requires. - good_indices = [ - i for i, stat in enumerate(stats) if stat["mean"] > threshold - ] + if len(gdf) == 0: + # Catch the case when there are no initial detections + good_indices = [] + else: + # Get the mean value of each detection polygon, as a list of + # [{"mean": }, ...] + stats = rasterstats.zonal_stats(gdf, mask, stats=["mean"], affine=transform) + # Get indices that have a greater fraction of good pixels than the + # threshold requires. + good_indices = [ + i for i, stat in enumerate(stats) if stat["mean"] > threshold + ] if isinstance(rds, RegionDetectionsSet): # Subset the RegionDetections object keeping only the valid indices From 68d6c0b92fd45ea4bd7431f5c7a5be918078efee Mon Sep 17 00:00:00 2001 From: Eric Schneider Date: Wed, 27 Aug 2025 17:56:03 +0000 Subject: [PATCH 3/7] Use and test visualization --- tests/test_visualization.py | 102 ++++++++++++++++++ .../utils/visualization.py | 44 ++++++-- 2 files changed, 135 insertions(+), 11 deletions(-) create mode 100644 tests/test_visualization.py diff --git a/tests/test_visualization.py b/tests/test_visualization.py new file mode 100644 index 0000000..3d5ec68 --- /dev/null +++ b/tests/test_visualization.py @@ -0,0 +1,102 @@ +import geopandas as gpd +import numpy as np +import pandas as pd +import pytest +from matplotlib import colormaps +from PIL import Image +from shapely.geometry import Polygon + +from tree_detection_framework.detection.region_detections import RegionDetectionsSet +from tree_detection_framework.utils.visualization import show_filtered_detections + + +# Helper to create a dummy image file +def save_dummy_image(imdir, size=(20, 20), color=(255, 255, 255)): + img = Image.new("RGB", size, color) + path = imdir / "test_image.png" + img.save(path) + return path + + +def save_dummy_gdf(gdir, N): + # Create DataFrame + df = pd.DataFrame( + { + "unique_ID": list(range(N)), + "geometry": [ + Polygon( + [ + (3 * i, 3 * i), + (3 * i + 2, 3 * i), + (3 * i + 2, 3 * i + 2), + (3 * i, 3 * i + 2), + ] + ) + for i in range(N) + ], + } + ) + # Save GeoDataFrame + gdf = gpd.GeoDataFrame(df, geometry="geometry") + path = gdir / f"test_gdf_{N}.gpkg" + gdf.to_file(path) + return path + + +class TestShowFilteredDetections: + + @pytest.mark.parametrize("mask_dtype", [int, bool]) + @pytest.mark.parametrize( + "mask_colormap", + [ + None, + { + 0: (np.array(colormaps["tab20"](0)) * 255).astype(np.uint8), + 1: (np.array(colormaps["tab20"](1)) * 255).astype(np.uint8), + } + ], + ) + def test_basic_functionality(self, tmp_path, mask_dtype, mask_colormap): + + # Create dummy files + impath = save_dummy_image(tmp_path) + gdf1 = save_dummy_gdf(tmp_path, 3) + gdf2 = save_dummy_gdf(tmp_path, 2) + + # Create mask + mask = np.zeros((20, 20), dtype=mask_dtype) + mask[:10, :10] = 1 + + # Run function + det_img, mask_img = show_filtered_detections( + impath=impath, + detection1=gdf1, + detection2=gdf2, + mask=mask, + mask_colormap=mask_colormap, + ) + assert det_img.shape == (20, 20, 3) + assert mask_img.shape == (20, 20, 3) + + # Spot check a couple of areas. We know there were three detections, two + # matches and a no-match + greenish = det_img[[1, 4], [1, 4]] + assert np.all(greenish[:, 1] > greenish[:, 0]) + assert np.all(greenish[:, 1] > greenish[:, 2]) + reddish = det_img[[7], [7]] + assert np.all(reddish[:, 0] > reddish[:, 1]) + assert np.all(reddish[:, 0] > reddish[:, 2]) + + # Check the mask map in a few spots + if mask_colormap is None: + # In this case the color is alpha blended with the image (white) so it + # won't be an exact match + assert np.argmax(mask_img[-1, -1]) == np.argmax(colormaps["tab20"](0)[:3]) + assert np.argmax(mask_img[0, 0]) == np.argmax(colormaps["tab20"](1)[:3]) + # The tab20[1] should be brighter + assert np.sum(mask_img[0, 0]) > np.sum(mask_img[-1, -1]) + else: + assert np.allclose(mask_img[-5:, -5:].reshape(-1, 3), mask_colormap[0][:3]) + assert np.allclose(mask_img[:5, :5].reshape(-1, 3), mask_colormap[1][:3]) + + assert np.any(mask_img != 255) # Should have some non-white pixels diff --git a/tree_detection_framework/utils/visualization.py b/tree_detection_framework/utils/visualization.py index c1d9700..f278dd1 100644 --- a/tree_detection_framework/utils/visualization.py +++ b/tree_detection_framework/utils/visualization.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Tuple, Union import geopandas as gpd import numpy as np @@ -15,17 +15,37 @@ def show_filtered_detections( detection2: Union[RegionDetectionsSet, PATH_TYPE], mask: np.ndarray, mask_colormap: Optional[dict] = None, -) -> List[np.ndarray]: +) -> Tuple[np.ndarray]: """ + Visualizes a full set of detections (detection1) against a filtered set of + detections (detection2) and the mask that acted as the filter. + Arguments: + impath (PATH_TYPE): Path to an (M, N, 3) RGB image. + detection1 (Union[RegionDetectionsSet, PATH_TYPE]): RegionDetectionsSet + derived from a specific drone image, or a geospatial file containing + the detections from a drone image. + detection2 (Union[RegionDetectionsSet, PATH_TYPE]): Same as detection1, + but this is assumed to be a filtered version (subset) of detection1. + mask (np.ndarray): (M, N) array of masked areas. Could be [True, False] + or it could have difference mask values per area, such as [0, 1, 2] + where each value means something like [invalid, ground, tree]. + mask_colormap (Optional[dict]): If None, the matplotlib tab20 color map + is applied to the mask values. A.k.a. a mask value of 1 is given a + color of tab20(1). If a dict is given, it should be of the form + {mask value: (4 element RGBA tuple, 0-255)} + Note that if you use the dict you can leave out certain mask values, + and they will be left uncolored. Defaults to None. - Returns: + Returns: Tuple of two images: + [0] Image with detections visualized + [1] Image with the mask visualized """ - # TODO + # Get an alpha-channel image image = Image.open(impath).convert("RGBA") - # TODO + # Load the given detection types as geopandas dataframes def to_gdf(detection): """Helper to unify the two input types""" if isinstance(detection, RegionDetectionsSet): @@ -35,11 +55,12 @@ def to_gdf(detection): gdf1 = to_gdf(detection1) gdf2 = to_gdf(detection2) - # TODO + # Draw the detections, coloring each detection in gdf1 based on whether + # it exists in gdf2 detection_overlay = Image.new("RGBA", image.size, (0, 0, 0, 0)) detection_draw = ImageDraw.Draw(detection_overlay, "RGBA") for idx, row in gdf1.iterrows(): - has_match = row["unique_ID"] in gdf2["unique_ID"] + has_match = row["unique_ID"] in gdf2["unique_ID"].values color = (0, 255, 0, 70) if has_match else (255, 0, 0, 70) if row.geometry.geom_type == "Polygon": polygons = [row.geometry] @@ -53,7 +74,7 @@ def to_gdf(detection): # coords are (x, y), but PIL expects (col, row) so we're good detection_draw.polygon(coords, outline=(0, 0, 0, 255), fill=color) - # TODO + # Create a colored overlay based on the given mask mask_overlay = np.zeros((image.height, image.width, 4), dtype=np.uint8) for value in np.unique(mask): if mask_colormap is None: @@ -61,11 +82,12 @@ def to_gdf(detection): # Set the alpha channel to partially transparent color[3] = 100 else: - color = mask_colormap[value] - mask_overlay[mask == value] = color + color = mask_colormap.get(value, None) + if color is not None: + mask_overlay[mask == value] = color mask_overlay = Image.fromarray(mask_overlay, mode="RGBA") - # TODO + # Return RGB image arrays def to_ndarray(overlay): combo = Image.alpha_composite(image, overlay) return np.asarray(combo.convert("RGB")) From 2352b42ec175c7931d4d1ecf9093b1c3b692a951 Mon Sep 17 00:00:00 2001 From: Eric Schneider Date: Wed, 27 Aug 2025 18:05:17 +0000 Subject: [PATCH 4/7] Formatting fixes --- tests/test_visualization.py | 2 +- tree_detection_framework/utils/visualization.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 3d5ec68..b806849 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -53,7 +53,7 @@ class TestShowFilteredDetections: { 0: (np.array(colormaps["tab20"](0)) * 255).astype(np.uint8), 1: (np.array(colormaps["tab20"](1)) * 255).astype(np.uint8), - } + }, ], ) def test_basic_functionality(self, tmp_path, mask_dtype, mask_colormap): diff --git a/tree_detection_framework/utils/visualization.py b/tree_detection_framework/utils/visualization.py index f278dd1..df9ec0d 100644 --- a/tree_detection_framework/utils/visualization.py +++ b/tree_detection_framework/utils/visualization.py @@ -52,6 +52,7 @@ def to_gdf(detection): return detection.merge().detections else: return gpd.read_file(detection) + gdf1 = to_gdf(detection1) gdf2 = to_gdf(detection2) @@ -91,7 +92,8 @@ def to_gdf(detection): def to_ndarray(overlay): combo = Image.alpha_composite(image, overlay) return np.asarray(combo.convert("RGB")) + return ( to_ndarray(detection_overlay), to_ndarray(mask_overlay), - ) \ No newline at end of file + ) From d20b8961680fcc46afbfb7e5c25e0ec179399358 Mon Sep 17 00:00:00 2001 From: FranzEricSchneider Date: Wed, 27 Aug 2025 22:01:04 +0000 Subject: [PATCH 5/7] Apply black formatting changes --- tree_detection_framework/postprocessing/postprocessing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tree_detection_framework/postprocessing/postprocessing.py b/tree_detection_framework/postprocessing/postprocessing.py index 41add00..a86e2aa 100644 --- a/tree_detection_framework/postprocessing/postprocessing.py +++ b/tree_detection_framework/postprocessing/postprocessing.py @@ -705,7 +705,9 @@ def iterator(rds): else: # Get the mean value of each detection polygon, as a list of # [{"mean": }, ...] - stats = rasterstats.zonal_stats(gdf, mask, stats=["mean"], affine=transform) + stats = rasterstats.zonal_stats( + gdf, mask, stats=["mean"], affine=transform + ) # Get indices that have a greater fraction of good pixels than the # threshold requires. good_indices = [ From 2b0e04de1b927a991c17d7872354ef24290760c0 Mon Sep 17 00:00:00 2001 From: Eric Schneider Date: Fri, 29 Aug 2025 15:00:48 -0500 Subject: [PATCH 6/7] Replace NotImplemented with warning per PR --- tree_detection_framework/utils/visualization.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tree_detection_framework/utils/visualization.py b/tree_detection_framework/utils/visualization.py index df9ec0d..fa8ad03 100644 --- a/tree_detection_framework/utils/visualization.py +++ b/tree_detection_framework/utils/visualization.py @@ -1,4 +1,5 @@ from typing import Optional, Tuple, Union +import warnings import geopandas as gpd import numpy as np @@ -68,7 +69,11 @@ def to_gdf(detection): elif row.geometry.geom_type == "MultiPolygon": polygons = list(row.geometry.geoms) else: - raise NotImplementedError(f"Can't handle {type(row.geometry)}") + warnings.warn( + f"show_filtered_detections found geometry {type(row.geometry)}" + " and was unable to display it", + category=UserWarning, + ) for poly in polygons: # Convert polygon to pixel coordinates coords = list(poly.exterior.coords) From 0753a6a97518c04b4da271d3356cac8be2cd2caa Mon Sep 17 00:00:00 2001 From: FranzEricSchneider <3709527+FranzEricSchneider@users.noreply.github.com> Date: Fri, 29 Aug 2025 20:01:16 +0000 Subject: [PATCH 7/7] Apply isort formatting changes --- tree_detection_framework/utils/visualization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tree_detection_framework/utils/visualization.py b/tree_detection_framework/utils/visualization.py index fa8ad03..3682e3f 100644 --- a/tree_detection_framework/utils/visualization.py +++ b/tree_detection_framework/utils/visualization.py @@ -1,5 +1,5 @@ -from typing import Optional, Tuple, Union import warnings +from typing import Optional, Tuple, Union import geopandas as gpd import numpy as np