Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tests/test_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,24 @@ def test_basic(self, tmp_path, flat, ellipse_radius, use_rds, geo_extension):
expected_map = {10: [2], 30: [2, 3, 4, 5, 6], 50: [2, 3, 4, 5, 6, 7, 8, 9, 10]}
for rds in filtered_detection_list:
check_expected_rds(rds, use_rds, expected_map[ellipse_radius])

@pytest.mark.parametrize(
"use_rds,geo_extension",
([True, None], [False, ".gpkg"], [False, ".geojson"], [False, ".shp"]),
)
def test_empty(self, tmp_path, use_rds, geo_extension):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran into an error with an empty GDF, this test replicated the error and then started passing when the error was fixed.


# Make an empty set of detections
detection_list = get_detections([], use_rds, tmp_path, geo_extension)

# Create a mask
mask = np.ones(shape=(150, 150), dtype=bool)

# Filter the empty detections
filtered_detection_list = remove_masked_detections(
region_detection_sets=detection_list,
mask_iterator=repeat(mask),
threshold=0.5,
)
for rds in filtered_detection_list:
check_expected_rds(rds, use_rds, [])
102 changes: 102 additions & 0 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import geopandas as gpd
import numpy as np
import pandas as pd
import pytest
from matplotlib import colormaps
from PIL import Image
from shapely.geometry import Polygon

from tree_detection_framework.detection.region_detections import RegionDetectionsSet
from tree_detection_framework.utils.visualization import show_filtered_detections


# Helper to create a dummy image file
def save_dummy_image(imdir, size=(20, 20), color=(255, 255, 255)):
img = Image.new("RGB", size, color)
path = imdir / "test_image.png"
img.save(path)
return path


def save_dummy_gdf(gdir, N):
# Create DataFrame
df = pd.DataFrame(
{
"unique_ID": list(range(N)),
"geometry": [
Polygon(
[
(3 * i, 3 * i),
(3 * i + 2, 3 * i),
(3 * i + 2, 3 * i + 2),
(3 * i, 3 * i + 2),
]
)
for i in range(N)
],
}
)
# Save GeoDataFrame
gdf = gpd.GeoDataFrame(df, geometry="geometry")
path = gdir / f"test_gdf_{N}.gpkg"
gdf.to_file(path)
return path


class TestShowFilteredDetections:

@pytest.mark.parametrize("mask_dtype", [int, bool])
@pytest.mark.parametrize(
"mask_colormap",
[
None,
{
0: (np.array(colormaps["tab20"](0)) * 255).astype(np.uint8),
1: (np.array(colormaps["tab20"](1)) * 255).astype(np.uint8),
},
],
)
def test_basic_functionality(self, tmp_path, mask_dtype, mask_colormap):

# Create dummy files
impath = save_dummy_image(tmp_path)
gdf1 = save_dummy_gdf(tmp_path, 3)
gdf2 = save_dummy_gdf(tmp_path, 2)

# Create mask
mask = np.zeros((20, 20), dtype=mask_dtype)
mask[:10, :10] = 1

# Run function
det_img, mask_img = show_filtered_detections(
impath=impath,
detection1=gdf1,
detection2=gdf2,
mask=mask,
mask_colormap=mask_colormap,
)
assert det_img.shape == (20, 20, 3)
assert mask_img.shape == (20, 20, 3)

# Spot check a couple of areas. We know there were three detections, two
# matches and a no-match
greenish = det_img[[1, 4], [1, 4]]
assert np.all(greenish[:, 1] > greenish[:, 0])
assert np.all(greenish[:, 1] > greenish[:, 2])
reddish = det_img[[7], [7]]
assert np.all(reddish[:, 0] > reddish[:, 1])
assert np.all(reddish[:, 0] > reddish[:, 2])

# Check the mask map in a few spots
if mask_colormap is None:
# In this case the color is alpha blended with the image (white) so it
# won't be an exact match
assert np.argmax(mask_img[-1, -1]) == np.argmax(colormaps["tab20"](0)[:3])
assert np.argmax(mask_img[0, 0]) == np.argmax(colormaps["tab20"](1)[:3])
# The tab20[1] should be brighter
assert np.sum(mask_img[0, 0]) > np.sum(mask_img[-1, -1])
else:
assert np.allclose(mask_img[-5:, -5:].reshape(-1, 3), mask_colormap[0][:3])
assert np.allclose(mask_img[:5, :5].reshape(-1, 3), mask_colormap[1][:3])

assert np.any(mask_img != 255) # Should have some non-white pixels
4 changes: 2 additions & 2 deletions tree_detection_framework/detection/region_detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def merge(
self,
region_ID_key: Optional[str] = "region_ID",
CRS: Optional[pyproj.CRS] = None,
):
) -> RegionDetections:
"""Get the merged detections across all regions with an additional field specifying which region
the detection came from.

Expand All @@ -478,7 +478,7 @@ def merge(
be used. Defaults to None.

Returns:
gpd.GeoDataFrame: Detections in the requested CRS
RegionDetections: Detections in the requested CRS
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just updating the docstring to match what the code is doing

"""

# Get the detections from each region detection object as geodataframes
Expand Down
25 changes: 15 additions & 10 deletions tree_detection_framework/postprocessing/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def remove_masked_detections(
region_detection_sets: Union[List[RegionDetectionsSet], List[PATH_TYPE]],
mask_iterator: Iterable[np.ndarray],
threshold: float = 0.4,
) -> List[RegionDetectionsSet]:
) -> Union[List[RegionDetectionsSet], List[gpd.GeoDataFrame]]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a change in behavior, just correcting the hint to match the behavior

"""
Filters out detections that marked as invalid in the given masks.

Expand Down Expand Up @@ -699,15 +699,20 @@ def iterator(rds):
# over the chips that detections were calculated in
for idx, gdf in iterator(rds):

# Get the mean value of each detection polygon, as a list of
# [{"mean": <value>}, ...]
stats = rasterstats.zonal_stats(gdf, mask, stats=["mean"], affine=transform)

# Get indices that have a greater fraction of good pixels than the
# threshold requires.
good_indices = [
i for i, stat in enumerate(stats) if stat["mean"] > threshold
]
if len(gdf) == 0:
# Catch the case when there are no initial detections
good_indices = []
else:
# Get the mean value of each detection polygon, as a list of
# [{"mean": <value>}, ...]
stats = rasterstats.zonal_stats(
gdf, mask, stats=["mean"], affine=transform
)
# Get indices that have a greater fraction of good pixels than the
# threshold requires.
good_indices = [
i for i, stat in enumerate(stats) if stat["mean"] > threshold
]

if isinstance(rds, RegionDetectionsSet):
# Subset the RegionDetections object keeping only the valid indices
Expand Down
2 changes: 1 addition & 1 deletion tree_detection_framework/utils/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def ellipse_mask(
center (Tuple[int, int]):
(x0, y0) center of the ellipse in x and y (pixels)
axes (Tuple[int, int]):
(a, b) semi-major and semi-minor axis lengths in pixels
(a, b) x and y axis lengths (before rotation) in pixels
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A more correct description

angle_rad (float):
Rotation angle of the semi-major axis, in radians. CCW from x-axis
(a.k.a. right-hand rule out of the image). Defaults to 0.
Expand Down
104 changes: 104 additions & 0 deletions tree_detection_framework/utils/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import warnings
from typing import Optional, Tuple, Union

import geopandas as gpd
import numpy as np
from matplotlib import colormaps
from PIL import Image, ImageDraw

from tree_detection_framework.constants import PATH_TYPE
from tree_detection_framework.detection.region_detections import RegionDetectionsSet


def show_filtered_detections(
impath: PATH_TYPE,
detection1: Union[RegionDetectionsSet, PATH_TYPE],
detection2: Union[RegionDetectionsSet, PATH_TYPE],
mask: np.ndarray,
mask_colormap: Optional[dict] = None,
) -> Tuple[np.ndarray]:
"""
Visualizes a full set of detections (detection1) against a filtered set of
detections (detection2) and the mask that acted as the filter.

Arguments:
impath (PATH_TYPE): Path to an (M, N, 3) RGB image.
detection1 (Union[RegionDetectionsSet, PATH_TYPE]): RegionDetectionsSet
derived from a specific drone image, or a geospatial file containing
the detections from a drone image.
detection2 (Union[RegionDetectionsSet, PATH_TYPE]): Same as detection1,
but this is assumed to be a filtered version (subset) of detection1.
mask (np.ndarray): (M, N) array of masked areas. Could be [True, False]
or it could have difference mask values per area, such as [0, 1, 2]
where each value means something like [invalid, ground, tree].
mask_colormap (Optional[dict]): If None, the matplotlib tab20 color map
is applied to the mask values. A.k.a. a mask value of 1 is given a
color of tab20(1). If a dict is given, it should be of the form
{mask value: (4 element RGBA tuple, 0-255)}
Note that if you use the dict you can leave out certain mask values,
and they will be left uncolored. Defaults to None.

Returns: Tuple of two images:
[0] Image with detections visualized
[1] Image with the mask visualized
"""

# Get an alpha-channel image
image = Image.open(impath).convert("RGBA")

# Load the given detection types as geopandas dataframes
def to_gdf(detection):
"""Helper to unify the two input types"""
if isinstance(detection, RegionDetectionsSet):
return detection.merge().detections
else:
return gpd.read_file(detection)

gdf1 = to_gdf(detection1)
gdf2 = to_gdf(detection2)

# Draw the detections, coloring each detection in gdf1 based on whether
# it exists in gdf2
detection_overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
detection_draw = ImageDraw.Draw(detection_overlay, "RGBA")
for idx, row in gdf1.iterrows():
has_match = row["unique_ID"] in gdf2["unique_ID"].values
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "unique_id" column gets created whenever merge() is done on a RegionDetectionsSet. So if the detections are being read from a .gpkg/.geojson file is it expected by the file to have the "unique_ID" column? If so should that be specified in the documentation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, the files that I was using (documented in the PR description) came from TDF and had that column so I assumed it would generally be available. I suppose

  1. Is there a better column available?
  2. If not, we can add id_column as an argument with unique_ID as the default

Copy link
Contributor

@amrithasp02 amrithasp02 Aug 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay. Any files created using TDF is guaranteed to have the unique_ID column. I was only asking for use cases where this could get input vector files generated outside of TDF. But I guess there isn't a possibility of that happening.

If needed, I think it would be good to add id_column as an argument with unique_ID default like you suggested, just to keep it more flexible.

color = (0, 255, 0, 70) if has_match else (255, 0, 0, 70)
if row.geometry.geom_type == "Polygon":
polygons = [row.geometry]
elif row.geometry.geom_type == "MultiPolygon":
polygons = list(row.geometry.geoms)
else:
warnings.warn(
f"show_filtered_detections found geometry {type(row.geometry)}"
" and was unable to display it",
category=UserWarning,
)
for poly in polygons:
# Convert polygon to pixel coordinates
coords = list(poly.exterior.coords)
# coords are (x, y), but PIL expects (col, row) so we're good
detection_draw.polygon(coords, outline=(0, 0, 0, 255), fill=color)

# Create a colored overlay based on the given mask
mask_overlay = np.zeros((image.height, image.width, 4), dtype=np.uint8)
for value in np.unique(mask):
if mask_colormap is None:
color = (np.array(colormaps["tab20"](value)) * 255).astype(np.uint8)
# Set the alpha channel to partially transparent
color[3] = 100
else:
color = mask_colormap.get(value, None)
if color is not None:
mask_overlay[mask == value] = color
mask_overlay = Image.fromarray(mask_overlay, mode="RGBA")

# Return RGB image arrays
def to_ndarray(overlay):
combo = Image.alpha_composite(image, overlay)
return np.asarray(combo.convert("RGB"))

return (
to_ndarray(detection_overlay),
to_ndarray(mask_overlay),
)