Skip to content

Commit 18bfc43

Browse files
cpuhrschjainapurva
authored andcommitted
More batching and improved furious accuracy/performance (#1253)
1 parent cdebc2d commit 18bfc43

File tree

9 files changed

+408
-156
lines changed

9 files changed

+408
-156
lines changed

examples/sam2_amg_server/README.md

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,29 @@ xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rl
2121

2222
Experiments run on H100 and with batch size 1
2323

24-
| mode | mIoU | mask count mismatch | avg. ms per request | batch size | points per batch |
25-
| --- | --- | ------------------- | ------------------- | ---------- | ---------------- |
26-
| baseline | 1.0 | 0 | 786 | 1 | 64 |
27-
| baseline | N/A | N/A | N/A | 32 | 1024 |
28-
| ao | 1.0 | 0 | 738 | 1 | 64 |
29-
| ao | 0.9999994993636996 | 0 | 564 | 32 | 1024 |
30-
| fast | 0.95 | 190 | 563 | 1 | 64 |
31-
| fast | 0.9527849197435295 | 191 | 460 | 32 | 1024 |
32-
| furious | 0 | 1000 | 204 | 1 | 64 |
33-
| furious | 0 | 1000 | 210 | 32 | 1024 |
24+
| mode | mIoU | mask count mismatch | avg. ms per request | max. memory (MiB (%)) | batch size | points per batch |
25+
| -------------- | ----------------- | ------------------- | ------------------- | --------------------- | ---------- | ---------------- |
26+
| baseline | 1.0 | 0 | 863 | 4013MiB (4%) | 1 | 64 |
27+
| ao | 1.0 | 0 | 840 | 4350MiB (4%) | 1 | 64 |
28+
| fast | 0.9897813200950623 | 191 | 661 | 3916MiB (4%) | 1 | 64 |
29+
| fast | 0.9897371530532837 | 192 | 388 | 50787MiB (52%) | 16 | 1024 |
30+
| fast + furious | 0.974319338798523 | 209 | 461 | 3453MiB (3%) | 1 | 64 |
31+
| fast + furious | 0.9702069759368896 | 196 | 195 | 48298MiB (49%) | 16 | 1024 |
3432

3533
mask count mismatch counts the number of requests where the number of masks differs from the baseline.
3634
For example, the baseline may have chosen to segment an image into 18 masks, but the fast variant produces 17 or 19.
3735
We exclude these examples from the mIoU calculation.
3836

37+
The 'ao' mode is a copy of the baseline with modifications that make the code compilable and improve the performance of the 'fast' mode.
38+
39+
### 0. Download checkpoints and install requirements
40+
41+
```
42+
pip install -r requirements.txt
43+
```
44+
45+
Download `sam2.1_hiera_large.pt` from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints and put it into `~/checkpoints/sam2`
46+
3947
### 1. Create a random subset of 1000 images
4048
```
4149
find sav_val -type f > sav_val_image_paths

examples/sam2_amg_server/compare_rle_lists.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import fire
22
import torch
33
import json
4-
from sam2.utils.amg import rle_to_mask
4+
from torchao._models.sam2.utils.amg import rle_to_mask
55

66
"""
77
Script to calculate mIoU given two lists of rles from upload_rle endpoint
@@ -16,6 +16,39 @@ def iou(mask1, mask2):
1616
union = torch.logical_or(mask1, mask2)
1717
return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2)))
1818

19+
def compare_masks(masks, ref_masks, order_by_area=False, verbose=False):
    """Accumulate the IoU between candidate masks and reference masks.

    Every key of ``ref_masks`` must also be present in ``masks``. For each
    such key both entries are decoded with ``rle_to_mask`` and their areas
    are read with ``area_from_rle``; the decoded masks are then paired up
    and scored with the module-level ``iou`` helper.

    Args:
        masks: Mapping from key to an RLE-encoded mask (candidate results).
            Presumably the per-request payload of the upload_rle endpoint --
            TODO confirm against the caller.
        ref_masks: Mapping with the same layout holding the reference
            results. Only its keys are iterated.
        order_by_area: When true, both mask lists are independently sorted
            by descending area before pairing, so masks are matched by area
            rank instead of by key.
        verbose: When true, print area mismatches and every pair whose IoU
            is not exactly 1.

    Returns:
        A ``(miou_sum, miou_count)`` tuple: the sum of the per-pair IoU
        values and the number of pairs. ``miou_sum`` is the float ``0.0``
        when there are no pairs; otherwise it is whatever ``0.0 + iou(...)``
        yields (presumably a torch tensor -- the caller can divide it by
        ``miou_count`` either way).

    Raises:
        AssertionError: If a key of ``ref_masks`` is missing from ``masks``.
    """
    # Kept as a function-scope import like the original, but done once
    # instead of on every loop iteration.
    from torchao._models.sam2.utils.amg import area_from_rle, rle_to_mask

    ref_entries = []  # (mask tensor, area, key) for the reference results
    new_entries = []  # (mask tensor, area) for the candidate results
    for key in ref_masks:
        assert key in masks, f"Expected {key} to be in return data"
        ref_area = area_from_rle(ref_masks[key])
        new_area = area_from_rle(masks[key])
        if (ref_area != new_area) and verbose:
            print(f"v0 area {ref_area} doesn't match v1 area {new_area}")
        ref_mask = torch.from_numpy(rle_to_mask(ref_masks[key]))
        new_mask = torch.from_numpy(rle_to_mask(masks[key]))
        ref_entries.append((ref_mask, ref_area, key))
        new_entries.append((new_mask, new_area))

    if order_by_area:
        # sorted() is stable, so equal-area masks keep their insertion order.
        ref_entries = sorted(ref_entries, key=(lambda x: x[1]), reverse=True)
        new_entries = sorted(new_entries, key=(lambda x: x[1]), reverse=True)

    miou_sum = 0.0
    miou_count = 0
    for (ref_mask, _, ref_key), (new_mask, _) in zip(ref_entries, new_entries):
        pair_iou = iou(ref_mask, new_mask)
        miou_sum += pair_iou
        miou_count += 1
        # Report the reference key that belongs to this pair (the loop
        # variable of the first loop would always name the last key), and
        # only when the masks really differ. A NaN IoU (empty union) is
        # reported as well because it does not compare equal to 1.
        if verbose and not bool((pair_iou == 1).all()):
            print(f"Masks don't match for key {ref_key}. IoU is {pair_iou}")

    return miou_sum, miou_count
51+
1952

2053
def main(path0, path1):
2154
fail_count = 0
@@ -28,11 +61,9 @@ def main(path0, path1):
2861
if masks0.keys() != masks1.keys():
2962
fail_count += 1
3063
continue
31-
for mask0, mask1 in zip(masks0.values(), masks1.values()):
32-
mask0 = torch.from_numpy(rle_to_mask(mask0))
33-
mask1 = torch.from_numpy(rle_to_mask(mask1))
34-
miou_sum += iou(mask0, mask1).item()
35-
miou_count += 1
64+
s, c = compare_masks(masks0, masks1, order_by_area=True)
65+
miou_sum += s
66+
miou_count += c
3667

3768
print(f"fail_count: {fail_count} mIoU: {miou_sum / miou_count}")
3869

0 commit comments

Comments
 (0)