Merged
95 changes: 92 additions & 3 deletions datasets.py
@@ -1,13 +1,102 @@
import torch
+import pathlib

+from torch.hub import tqdm

+from torchvision import datasets
from torchvision.transforms import functional as F_v1

+COCO_ROOT = "~/datasets/coco"

+__all__ = ["classification_dataset_builder", "detection_dataset_builder"]

-def classification_dataset_builder(*, input_type, api_version, rng, num_samples):
+def classification_dataset_builder(*, api_version, rng, num_samples):
return [
F_v1.to_pil_image(
# average size of images in ImageNet
-            torch.randint(0, 256, (3, 469, 387), dtype=torch.uint8, generator=rng)
+            torch.randint(0, 256, (3, 469, 387), dtype=torch.uint8, generator=rng),
)
for _ in range(num_samples)
]
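
An editorial aside, not part of the diff: a minimal sketch of how this builder can be exercised on its own, assuming the repo's datasets.py is importable as `datasets`. `to_pil_image` turns the (3, 469, 387) CHW tensor into a PIL image, and PIL reports size as (width, height):

import torch

import datasets  # assumed module name for this repo's datasets.py

rng = torch.Generator()
rng.manual_seed(0)
samples = datasets.classification_dataset_builder(api_version="v2", rng=rng, num_samples=2)
print(len(samples), samples[0].size)  # 2 (387, 469)
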


+def detection_dataset_builder(*, api_version, rng, num_samples):
+    root = pathlib.Path(COCO_ROOT).expanduser().resolve()
+    image_folder = str(root / "train2017")
+    annotation_file = str(root / "annotations" / "instances_train2017.json")
+    if api_version == "v1":
+        dataset = CocoDetectionV1(image_folder, annotation_file, transforms=None)
+    elif api_version == "v2":
+        dataset = datasets.CocoDetection(image_folder, annotation_file)
+    else:
+        raise ValueError(f"Got {api_version=}")
+
+    dataset = _coco_remove_images_without_annotations(dataset)
+
+    idcs = torch.randperm(len(dataset), generator=rng)[:num_samples].tolist()
+    print(f"Caching {num_samples} ({idcs[:3]} ... {idcs[-3:]}) COCO samples")
+    return [dataset[idx] for idx in tqdm(idcs)]


+# everything below is copy-pasted from
+# https://github.com/pytorch/vision/blob/main/references/detection/coco_utils.py
+
+import torch
+import torchvision
+
+
+class CocoDetectionV1(torchvision.datasets.CocoDetection):
+    def __init__(self, img_folder, ann_file, transforms):
+        super().__init__(img_folder, ann_file)
+        self._transforms = transforms
+
+    def __getitem__(self, idx):
+        img, target = super().__getitem__(idx)
+        image_id = self.ids[idx]
+        target = dict(image_id=image_id, annotations=target)
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+        return img, target
+
+
+def _coco_remove_images_without_annotations(dataset, cat_list=None):
+    def _has_only_empty_bbox(anno):
+        return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
+
+    def _count_visible_keypoints(anno):
+        return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
+
+    min_keypoints_per_image = 10
+
+    def _has_valid_annotation(anno):
+        # if it's empty, there is no annotation
+        if len(anno) == 0:
+            return False
+        # if all boxes have close to zero area, there is no annotation
+        if _has_only_empty_bbox(anno):
+            return False
+        # the keypoints task has slightly different criteria for considering
+        # an annotation valid
+        if "keypoints" not in anno[0]:
+            return True
+        # for keypoint detection tasks, only consider images valid if they
+        # contain at least min_keypoints_per_image visible keypoints
+        if _count_visible_keypoints(anno) >= min_keypoints_per_image:
+            return True
+        return False
+
+    if not isinstance(dataset, torchvision.datasets.CocoDetection):
+        raise TypeError(
+            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
+        )
+    ids = []
+    for ds_idx, img_id in enumerate(dataset.ids):
+        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
+        anno = dataset.coco.loadAnns(ann_ids)
+        if cat_list:
+            anno = [obj for obj in anno if obj["category_id"] in cat_list]
+        if _has_valid_annotation(anno):
+            ids.append(ds_idx)
+
+    dataset = torch.utils.data.Subset(dataset, ids)
+    return dataset
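
An editorial aside, not part of the diff: the filtering rules above can be checked against hand-made annotations (toy data, not real COCO). A bbox is [x, y, width, height], so an object whose width or height is <= 1 counts as degenerate, and a keypoint annotation stores flattened (x, y, visibility) triples, of which at least min_keypoints_per_image = 10 must be visible:

# toy COCO-style annotations, illustrative only
degenerate = [{"bbox": [10.0, 20.0, 0.5, 40.0]}]  # width <= 1 -> treated as empty
regular = [{"bbox": [10.0, 20.0, 50.0, 40.0]}]    # kept
with_keypoints = [
    # 9 visible keypoints (v=2) out of 17 -> below the threshold of 10
    {"bbox": [0, 0, 50, 50], "keypoints": [5, 5, 2] * 9 + [0, 0, 0] * 8}
]

def has_only_empty_bbox(anno):  # same predicate as in the diff above
    return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

def count_visible_keypoints(anno):  # same predicate as in the diff above
    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

print(has_only_empty_bbox(degenerate))          # True
print(has_only_empty_bbox(regular))             # False
print(count_visible_keypoints(with_keypoints))  # 9
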
141 changes: 78 additions & 63 deletions main.py
@@ -1,6 +1,7 @@
import contextlib
import itertools
import pathlib
+import string
import sys
from datetime import datetime

@@ -23,97 +24,111 @@ def write(self, message):
self.stdout.write(message)
self.file.write(message)

+    def flush(self):
(Reviewer comment from a contributor on this added line: "where is this needed?")
+        self.stdout.flush()
+        self.file.flush()
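
An editorial aside, not part of the diff: a self-contained sketch of the Tee pattern this class implements (the constructor is not shown in the diff, so the one below is an assumption). Once `contextlib.redirect_stdout` installs the Tee, anything that calls `sys.stdout.flush()`, including `print(..., flush=True)`, lands in `Tee.flush()`, which is presumably why the method was added:

import contextlib
import sys

class Tee:
    def __init__(self, stdout, path="log.txt"):  # assumed constructor
        self.stdout = stdout
        self.file = open(path, "w")

    def write(self, message):
        self.stdout.write(message)
        self.file.write(message)

    def flush(self):
        self.stdout.flush()
        self.file.flush()

with contextlib.redirect_stdout(Tee(stdout=sys.stdout)):
    print("shows up on the terminal and in log.txt", flush=True)
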


def main(*, input_types, tasks, num_samples):
# This is hardcoded when using a DataLoader with multiple workers:
# https://github.com/pytorch/pytorch/blob/19162083f8831be87be01bb84f186310cad1d348/torch/utils/data/_utils/worker.py#L222
torch.set_num_threads(1)

+    dataset_rng = torch.Generator()
+    dataset_rng.manual_seed(0)
+    dataset_rng_state = dataset_rng.get_state()

for task_name in tasks:
print("#" * 60)
print(task_name)
print("#" * 60)

medians = {input_type: {} for input_type in input_types}
-        for input_type in input_types:
-            dataset_rng = torch.Generator()
-            dataset_rng.manual_seed(0)
-            dataset_rng_state = dataset_rng.get_state()
-
-            for api_version in ["v1", "v2"]:
-                dataset_rng.set_state(dataset_rng_state)
-                task = make_task(
-                    task_name,
-                    input_type=input_type,
-                    api_version=api_version,
-                    dataset_rng=dataset_rng,
-                    num_samples=num_samples,
-                )
-                if task is None:
-                    continue
-
-                print(f"{input_type=}, {api_version=}")
-                print()
-                print(f"Results computed for {num_samples:_} samples")
-                print()
-
-                pipeline, dataset = task
-
-                for sample in dataset:
-                    pipeline(sample)
-
-                results = pipeline.extract_times()
-                field_len = max(len(name) for name in results)
-                print(f"{' ' * field_len} {'median ':>9} {'std ':>9}")
-                medians[input_type][api_version] = 0.0
-                for transform_name, times in results.items():
-                    median = float(times.median())
-                    print(
-                        f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
-                    )
-                    medians[input_type][api_version] += median
+        for input_type, api_version in itertools.product(input_types, ["v1", "v2"]):
+            dataset_rng.set_state(dataset_rng_state)
+            task = make_task(
+                task_name,
+                input_type=input_type,
+                api_version=api_version,
+                dataset_rng=dataset_rng,
+                num_samples=num_samples,
+            )
+            if task is None:
+                continue

-                print(
-                    f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs"
-                )
-                print("-" * 60)
print(f"{input_type=}, {api_version=}")
print()
print(f"Results computed for {num_samples:_} samples")
print()

-        print()
-        print("Summaries")
-        print()
+            pipeline, dataset = task

-        field_len = max(len(input_type) for input_type in medians)
-        print(f"{' ' * field_len} v2 / v1")
-        for input_type, api_versions in medians.items():
-            if len(api_versions) < 2:
-                continue
+            torch.manual_seed(0)
+            for sample in dataset:
+                pipeline(sample)

+            results = pipeline.extract_times()
+            field_len = max(len(name) for name in results)
+            print(f"{' ' * field_len} {'median ':>9} {'std ':>9}")
+            medians[input_type][api_version] = 0.0
+            for transform_name, times in results.items():
+                median = float(times.median())
+                print(
+                    f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
+                )
+                medians[input_type][api_version] += median

            print(
-                f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}"
+                f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs"
            )
print("-" * 60)

+        print()
+        print("Summaries")
+        print()

print()
+        field_len = max(len(input_type) for input_type in medians)
+        print(f"{' ' * field_len} v2 / v1")
+        for input_type, api_versions in medians.items():
+            if len(api_versions) < 2:
+                continue

-        median_ref = medians["PIL"]["v1"]
-        medians_flat = {
-            f"{input_type}, {api_version}": median
-            for input_type, api_versions in medians.items()
-            for api_version, median in api_versions.items()
-        }
-        field_len = max(len(label) for label in medians_flat)
-        print(f"{' ' * field_len} x / PIL, v1")
-        for label, median in medians_flat.items():
-            print(f"{label:{field_len}} {median / median_ref:>11.2f}")
+            print(
+                f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}"
+            )

print()

+        medians_flat = {
+            f"{input_type}, {api_version}": median
+            for input_type, api_versions in medians.items()
+            for api_version, median in api_versions.items()
+        }
+        field_len = max(len(label) for label in medians_flat)

+        print(
+            f"{' ' * (field_len + 5)} {' '.join(f' [{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}"
+        )
+        for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase):
+            print(
+                f"{label:>{field_len}}, [{id}] {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}"
+            )
+        print()
+        print("Slowdown as row / col")


if __name__ == "__main__":
tee = Tee(stdout=sys.stdout)

with contextlib.redirect_stdout(tee):
main(
tasks=["classification-simple", "classification-complex"],
tasks=[
"classification-simple",
"classification-complex",
"detection-ssdlite",
],
input_types=["Tensor", "PIL", "Datapoint"],
-            num_samples=10_000,
+            num_samples=1_000,
)

print("#" * 60)