-
Notifications
You must be signed in to change notification settings - Fork 4
Feature/fes/region filtering #153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
36b2109
00dbdbe
68d6c0b
2352b42
d20b896
2b0e04d
0753a6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import geopandas as gpd | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
from matplotlib import colormaps | ||
from PIL import Image | ||
from shapely.geometry import Polygon | ||
|
||
from tree_detection_framework.detection.region_detections import RegionDetectionsSet | ||
from tree_detection_framework.utils.visualization import show_filtered_detections | ||
|
||
|
||
# Helper to create a dummy image file | ||
def save_dummy_image(imdir, size=(20, 20), color=(255, 255, 255)): | ||
img = Image.new("RGB", size, color) | ||
path = imdir / "test_image.png" | ||
img.save(path) | ||
return path | ||
|
||
|
||
def save_dummy_gdf(gdir, N): | ||
# Create DataFrame | ||
df = pd.DataFrame( | ||
{ | ||
"unique_ID": list(range(N)), | ||
"geometry": [ | ||
Polygon( | ||
[ | ||
(3 * i, 3 * i), | ||
(3 * i + 2, 3 * i), | ||
(3 * i + 2, 3 * i + 2), | ||
(3 * i, 3 * i + 2), | ||
] | ||
) | ||
for i in range(N) | ||
], | ||
} | ||
) | ||
# Save GeoDataFrame | ||
gdf = gpd.GeoDataFrame(df, geometry="geometry") | ||
path = gdir / f"test_gdf_{N}.gpkg" | ||
gdf.to_file(path) | ||
return path | ||
|
||
|
||
class TestShowFilteredDetections: | ||
|
||
@pytest.mark.parametrize("mask_dtype", [int, bool]) | ||
@pytest.mark.parametrize( | ||
"mask_colormap", | ||
[ | ||
None, | ||
{ | ||
0: (np.array(colormaps["tab20"](0)) * 255).astype(np.uint8), | ||
1: (np.array(colormaps["tab20"](1)) * 255).astype(np.uint8), | ||
}, | ||
], | ||
) | ||
def test_basic_functionality(self, tmp_path, mask_dtype, mask_colormap): | ||
|
||
# Create dummy files | ||
impath = save_dummy_image(tmp_path) | ||
gdf1 = save_dummy_gdf(tmp_path, 3) | ||
gdf2 = save_dummy_gdf(tmp_path, 2) | ||
|
||
# Create mask | ||
mask = np.zeros((20, 20), dtype=mask_dtype) | ||
mask[:10, :10] = 1 | ||
|
||
# Run function | ||
det_img, mask_img = show_filtered_detections( | ||
impath=impath, | ||
detection1=gdf1, | ||
detection2=gdf2, | ||
mask=mask, | ||
mask_colormap=mask_colormap, | ||
) | ||
assert det_img.shape == (20, 20, 3) | ||
assert mask_img.shape == (20, 20, 3) | ||
|
||
# Spot check a couple of areas. We know there were three detections, two | ||
# matches and a no-match | ||
greenish = det_img[[1, 4], [1, 4]] | ||
assert np.all(greenish[:, 1] > greenish[:, 0]) | ||
assert np.all(greenish[:, 1] > greenish[:, 2]) | ||
reddish = det_img[[7], [7]] | ||
assert np.all(reddish[:, 0] > reddish[:, 1]) | ||
assert np.all(reddish[:, 0] > reddish[:, 2]) | ||
|
||
# Check the mask map in a few spots | ||
if mask_colormap is None: | ||
# In this case the color is alpha blended with the image (white) so it | ||
# won't be an exact match | ||
assert np.argmax(mask_img[-1, -1]) == np.argmax(colormaps["tab20"](0)[:3]) | ||
assert np.argmax(mask_img[0, 0]) == np.argmax(colormaps["tab20"](1)[:3]) | ||
# The tab20[1] should be brighter | ||
assert np.sum(mask_img[0, 0]) > np.sum(mask_img[-1, -1]) | ||
else: | ||
assert np.allclose(mask_img[-5:, -5:].reshape(-1, 3), mask_colormap[0][:3]) | ||
assert np.allclose(mask_img[:5, :5].reshape(-1, 3), mask_colormap[1][:3]) | ||
|
||
assert np.any(mask_img != 255) # Should have some non-white pixels |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -465,7 +465,7 @@ def merge( | |
self, | ||
region_ID_key: Optional[str] = "region_ID", | ||
CRS: Optional[pyproj.CRS] = None, | ||
): | ||
) -> RegionDetections: | ||
"""Get the merged detections across all regions with an additional field specifying which region | ||
the detection came from. | ||
|
||
|
@@ -478,7 +478,7 @@ def merge( | |
be used. Defaults to None. | ||
|
||
Returns: | ||
gpd.GeoDataFrame: Detections in the requested CRS | ||
RegionDetections: Detections in the requested CRS | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just updating the docstring to match what the code is doing |
||
""" | ||
|
||
# Get the detections from each region detection object as geodataframes | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -630,7 +630,7 @@ def remove_masked_detections( | |
region_detection_sets: Union[List[RegionDetectionsSet], List[PATH_TYPE]], | ||
mask_iterator: Iterable[np.ndarray], | ||
threshold: float = 0.4, | ||
) -> List[RegionDetectionsSet]: | ||
) -> Union[List[RegionDetectionsSet], List[gpd.GeoDataFrame]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a change in behavior, just correcting the hint to match the behavior |
||
""" | ||
Filters out detections that marked as invalid in the given masks. | ||
|
||
|
@@ -699,15 +699,20 @@ def iterator(rds): | |
# over the chips that detections were calculated in | ||
for idx, gdf in iterator(rds): | ||
|
||
# Get the mean value of each detection polygon, as a list of | ||
# [{"mean": <value>}, ...] | ||
stats = rasterstats.zonal_stats(gdf, mask, stats=["mean"], affine=transform) | ||
|
||
# Get indices that have a greater fraction of good pixels than the | ||
# threshold requires. | ||
good_indices = [ | ||
i for i, stat in enumerate(stats) if stat["mean"] > threshold | ||
] | ||
if len(gdf) == 0: | ||
# Catch the case when there are no initial detections | ||
good_indices = [] | ||
else: | ||
# Get the mean value of each detection polygon, as a list of | ||
# [{"mean": <value>}, ...] | ||
stats = rasterstats.zonal_stats( | ||
gdf, mask, stats=["mean"], affine=transform | ||
) | ||
# Get indices that have a greater fraction of good pixels than the | ||
# threshold requires. | ||
good_indices = [ | ||
i for i, stat in enumerate(stats) if stat["mean"] > threshold | ||
] | ||
|
||
if isinstance(rds, RegionDetectionsSet): | ||
# Subset the RegionDetections object keeping only the valid indices | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -115,7 +115,7 @@ def ellipse_mask( | |
center (Tuple[int, int]): | ||
(x0, y0) center of the ellipse in x and y (pixels) | ||
axes (Tuple[int, int]): | ||
(a, b) semi-major and semi-minor axis lengths in pixels | ||
(a, b) x and y axis lengths (before rotation) in pixels | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A more correct description |
||
angle_rad (float): | ||
Rotation angle of the semi-major axis, in radians. CCW from x-axis | ||
(a.k.a. right-hand rule out of the image). Defaults to 0. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import warnings | ||
from typing import Optional, Tuple, Union | ||
|
||
import geopandas as gpd | ||
import numpy as np | ||
from matplotlib import colormaps | ||
from PIL import Image, ImageDraw | ||
|
||
from tree_detection_framework.constants import PATH_TYPE | ||
from tree_detection_framework.detection.region_detections import RegionDetectionsSet | ||
|
||
|
||
def show_filtered_detections( | ||
impath: PATH_TYPE, | ||
detection1: Union[RegionDetectionsSet, PATH_TYPE], | ||
detection2: Union[RegionDetectionsSet, PATH_TYPE], | ||
mask: np.ndarray, | ||
mask_colormap: Optional[dict] = None, | ||
) -> Tuple[np.ndarray]: | ||
""" | ||
Visualizes a full set of detections (detection1) against a filtered set of | ||
detections (detection2) and the mask that acted as the filter. | ||
|
||
Arguments: | ||
impath (PATH_TYPE): Path to an (M, N, 3) RGB image. | ||
detection1 (Union[RegionDetectionsSet, PATH_TYPE]): RegionDetectionsSet | ||
derived from a specific drone image, or a geospatial file containing | ||
the detections from a drone image. | ||
detection2 (Union[RegionDetectionsSet, PATH_TYPE]): Same as detection1, | ||
but this is assumed to be a filtered version (subset) of detection1. | ||
mask (np.ndarray): (M, N) array of masked areas. Could be [True, False] | ||
or it could have difference mask values per area, such as [0, 1, 2] | ||
where each value means something like [invalid, ground, tree]. | ||
mask_colormap (Optional[dict]): If None, the matplotlib tab20 color map | ||
is applied to the mask values. A.k.a. a mask value of 1 is given a | ||
color of tab20(1). If a dict is given, it should be of the form | ||
{mask value: (4 element RGBA tuple, 0-255)} | ||
Note that if you use the dict you can leave out certain mask values, | ||
and they will be left uncolored. Defaults to None. | ||
|
||
Returns: Tuple of two images: | ||
[0] Image with detections visualized | ||
[1] Image with the mask visualized | ||
""" | ||
|
||
# Get an alpha-channel image | ||
image = Image.open(impath).convert("RGBA") | ||
|
||
# Load the given detection types as geopandas dataframes | ||
def to_gdf(detection): | ||
"""Helper to unify the two input types""" | ||
if isinstance(detection, RegionDetectionsSet): | ||
return detection.merge().detections | ||
else: | ||
return gpd.read_file(detection) | ||
|
||
gdf1 = to_gdf(detection1) | ||
gdf2 = to_gdf(detection2) | ||
|
||
# Draw the detections, coloring each detection in gdf1 based on whether | ||
# it exists in gdf2 | ||
detection_overlay = Image.new("RGBA", image.size, (0, 0, 0, 0)) | ||
detection_draw = ImageDraw.Draw(detection_overlay, "RGBA") | ||
for idx, row in gdf1.iterrows(): | ||
has_match = row["unique_ID"] in gdf2["unique_ID"].values | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The "unique_id" column gets created whenever There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure, the files that I was using (documented in the PR description) came from TDF and had that column so I assumed it would generally be available. I suppose
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah okay. Any files created using TDF is guaranteed to have the If needed, I think it would be good to add |
||
color = (0, 255, 0, 70) if has_match else (255, 0, 0, 70) | ||
if row.geometry.geom_type == "Polygon": | ||
polygons = [row.geometry] | ||
elif row.geometry.geom_type == "MultiPolygon": | ||
polygons = list(row.geometry.geoms) | ||
else: | ||
warnings.warn( | ||
f"show_filtered_detections found geometry {type(row.geometry)}" | ||
" and was unable to display it", | ||
category=UserWarning, | ||
) | ||
for poly in polygons: | ||
# Convert polygon to pixel coordinates | ||
coords = list(poly.exterior.coords) | ||
FranzEricSchneider marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# coords are (x, y), but PIL expects (col, row) so we're good | ||
detection_draw.polygon(coords, outline=(0, 0, 0, 255), fill=color) | ||
|
||
# Create a colored overlay based on the given mask | ||
mask_overlay = np.zeros((image.height, image.width, 4), dtype=np.uint8) | ||
for value in np.unique(mask): | ||
if mask_colormap is None: | ||
color = (np.array(colormaps["tab20"](value)) * 255).astype(np.uint8) | ||
# Set the alpha channel to partially transparent | ||
color[3] = 100 | ||
else: | ||
color = mask_colormap.get(value, None) | ||
if color is not None: | ||
mask_overlay[mask == value] = color | ||
mask_overlay = Image.fromarray(mask_overlay, mode="RGBA") | ||
|
||
# Return RGB image arrays | ||
def to_ndarray(overlay): | ||
combo = Image.alpha_composite(image, overlay) | ||
return np.asarray(combo.convert("RGB")) | ||
|
||
return ( | ||
to_ndarray(detection_overlay), | ||
to_ndarray(mask_overlay), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ran into an error with an empty GDF, this test replicated the error and then started passing when the error was fixed.