diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py index 3adeff6..cc75cec 100644 --- a/tests/test_postprocessing.py +++ b/tests/test_postprocessing.py @@ -259,3 +259,24 @@ def test_basic(self, tmp_path, flat, ellipse_radius, use_rds, geo_extension): expected_map = {10: [2], 30: [2, 3, 4, 5, 6], 50: [2, 3, 4, 5, 6, 7, 8, 9, 10]} for rds in filtered_detection_list: check_expected_rds(rds, use_rds, expected_map[ellipse_radius]) + + @pytest.mark.parametrize( + "use_rds,geo_extension", + ([True, None], [False, ".gpkg"], [False, ".geojson"], [False, ".shp"]), + ) + def test_empty(self, tmp_path, use_rds, geo_extension): + + # Make an empty set of detections + detection_list = get_detections([], use_rds, tmp_path, geo_extension) + + # Create a mask + mask = np.ones(shape=(150, 150), dtype=bool) + + # Filter the empty detections + filtered_detection_list = remove_masked_detections( + region_detection_sets=detection_list, + mask_iterator=repeat(mask), + threshold=0.5, + ) + for rds in filtered_detection_list: + check_expected_rds(rds, use_rds, []) diff --git a/tests/test_visualization.py b/tests/test_visualization.py new file mode 100644 index 0000000..b806849 --- /dev/null +++ b/tests/test_visualization.py @@ -0,0 +1,102 @@ +import geopandas as gpd +import numpy as np +import pandas as pd +import pytest +from matplotlib import colormaps +from PIL import Image +from shapely.geometry import Polygon + +from tree_detection_framework.detection.region_detections import RegionDetectionsSet +from tree_detection_framework.utils.visualization import show_filtered_detections + + +# Helper to create a dummy image file +def save_dummy_image(imdir, size=(20, 20), color=(255, 255, 255)): + img = Image.new("RGB", size, color) + path = imdir / "test_image.png" + img.save(path) + return path + + +def save_dummy_gdf(gdir, N): + # Create DataFrame + df = pd.DataFrame( + { + "unique_ID": list(range(N)), + "geometry": [ + Polygon( + [ + (3 * i, 3 * i), + (3 * i + 2, 3 * i), + (3 * i + 2, 3 * i + 2), + (3 * i, 3 * i + 2), + ] + ) + for i in range(N) + ], + } + ) + # Save GeoDataFrame + gdf = gpd.GeoDataFrame(df, geometry="geometry") + path = gdir / f"test_gdf_{N}.gpkg" + gdf.to_file(path) + return path + + +class TestShowFilteredDetections: + + @pytest.mark.parametrize("mask_dtype", [int, bool]) + @pytest.mark.parametrize( + "mask_colormap", + [ + None, + { + 0: (np.array(colormaps["tab20"](0)) * 255).astype(np.uint8), + 1: (np.array(colormaps["tab20"](1)) * 255).astype(np.uint8), + }, + ], + ) + def test_basic_functionality(self, tmp_path, mask_dtype, mask_colormap): + + # Create dummy files + impath = save_dummy_image(tmp_path) + gdf1 = save_dummy_gdf(tmp_path, 3) + gdf2 = save_dummy_gdf(tmp_path, 2) + + # Create mask + mask = np.zeros((20, 20), dtype=mask_dtype) + mask[:10, :10] = 1 + + # Run function + det_img, mask_img = show_filtered_detections( + impath=impath, + detection1=gdf1, + detection2=gdf2, + mask=mask, + mask_colormap=mask_colormap, + ) + assert det_img.shape == (20, 20, 3) + assert mask_img.shape == (20, 20, 3) + + # Spot check a couple of areas. We know there were three detections, two + # matches and a no-match + greenish = det_img[[1, 4], [1, 4]] + assert np.all(greenish[:, 1] > greenish[:, 0]) + assert np.all(greenish[:, 1] > greenish[:, 2]) + reddish = det_img[[7], [7]] + assert np.all(reddish[:, 0] > reddish[:, 1]) + assert np.all(reddish[:, 0] > reddish[:, 2]) + + # Check the mask map in a few spots + if mask_colormap is None: + # In this case the color is alpha blended with the image (white) so it + # won't be an exact match + assert np.argmax(mask_img[-1, -1]) == np.argmax(colormaps["tab20"](0)[:3]) + assert np.argmax(mask_img[0, 0]) == np.argmax(colormaps["tab20"](1)[:3]) + # The tab20[1] should be brighter + assert np.sum(mask_img[0, 0]) > np.sum(mask_img[-1, -1]) + else: + assert np.allclose(mask_img[-5:, -5:].reshape(-1, 3), mask_colormap[0][:3]) + assert np.allclose(mask_img[:5, :5].reshape(-1, 3), mask_colormap[1][:3]) + + assert np.any(mask_img != 255) # Should have some non-white pixels diff --git a/tree_detection_framework/detection/region_detections.py b/tree_detection_framework/detection/region_detections.py index 4782788..8029cc8 100644 --- a/tree_detection_framework/detection/region_detections.py +++ b/tree_detection_framework/detection/region_detections.py @@ -465,7 +465,7 @@ def merge( self, region_ID_key: Optional[str] = "region_ID", CRS: Optional[pyproj.CRS] = None, - ): + ) -> RegionDetections: """Get the merged detections across all regions with an additional field specifying which region the detection came from. @@ -478,7 +478,7 @@ def merge( be used. Defaults to None. Returns: - gpd.GeoDataFrame: Detections in the requested CRS + RegionDetections: Detections in the requested CRS """ # Get the detections from each region detection object as geodataframes diff --git a/tree_detection_framework/postprocessing/postprocessing.py b/tree_detection_framework/postprocessing/postprocessing.py index 9d10951..a86e2aa 100644 --- a/tree_detection_framework/postprocessing/postprocessing.py +++ b/tree_detection_framework/postprocessing/postprocessing.py @@ -630,7 +630,7 @@ def remove_masked_detections( region_detection_sets: Union[List[RegionDetectionsSet], List[PATH_TYPE]], mask_iterator: Iterable[np.ndarray], threshold: float = 0.4, -) -> List[RegionDetectionsSet]: +) -> Union[List[RegionDetectionsSet], List[gpd.GeoDataFrame]]: """ Filters out detections that marked as invalid in the given masks. @@ -699,15 +699,20 @@ def iterator(rds): # over the chips that detections were calculated in for idx, gdf in iterator(rds): - # Get the mean value of each detection polygon, as a list of - # [{"mean": }, ...] - stats = rasterstats.zonal_stats(gdf, mask, stats=["mean"], affine=transform) - - # Get indices that have a greater fraction of good pixels than the - # threshold requires. - good_indices = [ - i for i, stat in enumerate(stats) if stat["mean"] > threshold - ] + if len(gdf) == 0: + # Catch the case when there are no initial detections + good_indices = [] + else: + # Get the mean value of each detection polygon, as a list of + # [{"mean": }, ...] + stats = rasterstats.zonal_stats( + gdf, mask, stats=["mean"], affine=transform + ) + # Get indices that have a greater fraction of good pixels than the + # threshold requires. + good_indices = [ + i for i, stat in enumerate(stats) if stat["mean"] > threshold + ] if isinstance(rds, RegionDetectionsSet): # Subset the RegionDetections object keeping only the valid indices diff --git a/tree_detection_framework/utils/geometric.py b/tree_detection_framework/utils/geometric.py index 6912f2e..73ab452 100644 --- a/tree_detection_framework/utils/geometric.py +++ b/tree_detection_framework/utils/geometric.py @@ -115,7 +115,7 @@ def ellipse_mask( center (Tuple[int, int]): (x0, y0) center of the ellipse in x and y (pixels) axes (Tuple[int, int]): - (a, b) semi-major and semi-minor axis lengths in pixels + (a, b) x and y axis lengths (before rotation) in pixels angle_rad (float): Rotation angle of the semi-major axis, in radians. CCW from x-axis (a.k.a. right-hand rule out of the image). Defaults to 0. diff --git a/tree_detection_framework/utils/visualization.py b/tree_detection_framework/utils/visualization.py new file mode 100644 index 0000000..3682e3f --- /dev/null +++ b/tree_detection_framework/utils/visualization.py @@ -0,0 +1,104 @@ +import warnings +from typing import Optional, Tuple, Union + +import geopandas as gpd +import numpy as np +from matplotlib import colormaps +from PIL import Image, ImageDraw + +from tree_detection_framework.constants import PATH_TYPE +from tree_detection_framework.detection.region_detections import RegionDetectionsSet + + +def show_filtered_detections( + impath: PATH_TYPE, + detection1: Union[RegionDetectionsSet, PATH_TYPE], + detection2: Union[RegionDetectionsSet, PATH_TYPE], + mask: np.ndarray, + mask_colormap: Optional[dict] = None, +) -> Tuple[np.ndarray]: + """ + Visualizes a full set of detections (detection1) against a filtered set of + detections (detection2) and the mask that acted as the filter. + + Arguments: + impath (PATH_TYPE): Path to an (M, N, 3) RGB image. + detection1 (Union[RegionDetectionsSet, PATH_TYPE]): RegionDetectionsSet + derived from a specific drone image, or a geospatial file containing + the detections from a drone image. + detection2 (Union[RegionDetectionsSet, PATH_TYPE]): Same as detection1, + but this is assumed to be a filtered version (subset) of detection1. + mask (np.ndarray): (M, N) array of masked areas. Could be [True, False] + or it could have difference mask values per area, such as [0, 1, 2] + where each value means something like [invalid, ground, tree]. + mask_colormap (Optional[dict]): If None, the matplotlib tab20 color map + is applied to the mask values. A.k.a. a mask value of 1 is given a + color of tab20(1). If a dict is given, it should be of the form + {mask value: (4 element RGBA tuple, 0-255)} + Note that if you use the dict you can leave out certain mask values, + and they will be left uncolored. Defaults to None. + + Returns: Tuple of two images: + [0] Image with detections visualized + [1] Image with the mask visualized + """ + + # Get an alpha-channel image + image = Image.open(impath).convert("RGBA") + + # Load the given detection types as geopandas dataframes + def to_gdf(detection): + """Helper to unify the two input types""" + if isinstance(detection, RegionDetectionsSet): + return detection.merge().detections + else: + return gpd.read_file(detection) + + gdf1 = to_gdf(detection1) + gdf2 = to_gdf(detection2) + + # Draw the detections, coloring each detection in gdf1 based on whether + # it exists in gdf2 + detection_overlay = Image.new("RGBA", image.size, (0, 0, 0, 0)) + detection_draw = ImageDraw.Draw(detection_overlay, "RGBA") + for idx, row in gdf1.iterrows(): + has_match = row["unique_ID"] in gdf2["unique_ID"].values + color = (0, 255, 0, 70) if has_match else (255, 0, 0, 70) + if row.geometry.geom_type == "Polygon": + polygons = [row.geometry] + elif row.geometry.geom_type == "MultiPolygon": + polygons = list(row.geometry.geoms) + else: + warnings.warn( + f"show_filtered_detections found geometry {type(row.geometry)}" + " and was unable to display it", + category=UserWarning, + ) + for poly in polygons: + # Convert polygon to pixel coordinates + coords = list(poly.exterior.coords) + # coords are (x, y), but PIL expects (col, row) so we're good + detection_draw.polygon(coords, outline=(0, 0, 0, 255), fill=color) + + # Create a colored overlay based on the given mask + mask_overlay = np.zeros((image.height, image.width, 4), dtype=np.uint8) + for value in np.unique(mask): + if mask_colormap is None: + color = (np.array(colormaps["tab20"](value)) * 255).astype(np.uint8) + # Set the alpha channel to partially transparent + color[3] = 100 + else: + color = mask_colormap.get(value, None) + if color is not None: + mask_overlay[mask == value] = color + mask_overlay = Image.fromarray(mask_overlay, mode="RGBA") + + # Return RGB image arrays + def to_ndarray(overlay): + combo = Image.alpha_composite(image, overlay) + return np.asarray(combo.convert("RGB")) + + return ( + to_ndarray(detection_overlay), + to_ndarray(mask_overlay), + )