diff --git a/references/video_classification/datasets.py b/references/video_classification/datasets.py
new file mode 100644
index 00000000000..dec1e16b856
--- /dev/null
+++ b/references/video_classification/datasets.py
@@ -0,0 +1,15 @@
+from typing import Tuple
+
+import torchvision
+from torch import Tensor
+
+
+class KineticsWithVideoId(torchvision.datasets.Kinetics):
+    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int, int]:
+        video, audio, info, video_idx = self.video_clips.get_clip(idx)
+        label = self.samples[video_idx][1]
+
+        if self.transform is not None:
+            video = self.transform(video)
+
+        return video, audio, label, video_idx
diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index 016e6024886..4da8331a1c6 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -3,6 +3,7 @@
 import time
 import warnings
 
+import datasets
 import presets
 import torch
 import torch.utils.data
@@ -11,7 +12,7 @@
 import utils
 from torch import nn
 from torch.utils.data.dataloader import default_collate
-from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
+from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler
 
 
 def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
@@ -21,7 +22,7 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
     metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}"))
 
     header = f"Epoch: [{epoch}]"
-    for video, target in metric_logger.log_every(data_loader, print_freq, header):
+    for video, target, _ in metric_logger.log_every(data_loader, print_freq, header):
         start_time = time.time()
         video, target = video.to(device), target.to(device)
         with torch.cuda.amp.autocast(enabled=scaler is not None):
@@ -52,13 +53,25 @@ def evaluate(model, criterion, data_loader, device):
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = "Test:"
     num_processed_samples = 0
+    # Group and aggregate output of a video
+    num_videos = len(data_loader.dataset.samples)
+    num_classes = len(data_loader.dataset.classes)
+    agg_preds = torch.zeros((num_videos, num_classes), dtype=torch.float32, device=device)
+    agg_targets = torch.zeros((num_videos), dtype=torch.int32, device=device)
     with torch.inference_mode():
-        for video, target in metric_logger.log_every(data_loader, 100, header):
+        for video, target, video_idx in metric_logger.log_every(data_loader, 100, header):
             video = video.to(device, non_blocking=True)
             target = target.to(device, non_blocking=True)
             output = model(video)
             loss = criterion(output, target)
 
+            # Use softmax to convert output into prediction probability
+            preds = torch.softmax(output, dim=1)
+            for b in range(video.size(0)):
+                idx = video_idx[b].item()
+                agg_preds[idx] += preds[b].detach()
+                agg_targets[idx] = target[b].detach().item()
+
             acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
             # FIXME need to take into account that the datasets
             # could have been padded in distributed setup
@@ -95,6 +108,11 @@ def evaluate(model, criterion, data_loader, device):
             top1=metric_logger.acc1, top5=metric_logger.acc5
         )
     )
+    # Reduce the agg_preds and agg_targets from all gpu and show result
+    agg_preds = utils.reduce_across_processes(agg_preds)
+    agg_targets = utils.reduce_across_processes(agg_targets, op=torch.distributed.ReduceOp.MAX)
+    agg_acc1, agg_acc5 = utils.accuracy(agg_preds, agg_targets, topk=(1, 5))
+    print(" * Video Acc@1 {acc1:.3f} Video Acc@5 {acc5:.3f}".format(acc1=agg_acc1, acc5=agg_acc5))
     return metric_logger.acc1.global_avg
@@ -110,7 +128,7 @@ def _get_cache_path(filepath, args):
 
 
 def collate_fn(batch):
     # remove audio from the batch
-    batch = [(d[0], d[2]) for d in batch]
+    batch = [(d[0], d[2], d[3]) for d in batch]
     return default_collate(batch)
@@ -146,7 +164,7 @@ def main(args):
     else:
         if args.distributed:
             print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
-        dataset = torchvision.datasets.Kinetics(
+        dataset = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
@@ -183,7 +201,7 @@ def main(args):
     else:
         if args.distributed:
             print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
-        dataset_test = torchvision.datasets.Kinetics(
+        dataset_test = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
@@ -313,10 +331,10 @@ def main(args):
     print(f"Training time {total_time_str}")
 
 
-def parse_args():
+def get_args_parser(add_help=True):
     import argparse
 
-    parser = argparse.ArgumentParser(description="PyTorch Video Classification Training")
+    parser = argparse.ArgumentParser(description="PyTorch Video Classification Training", add_help=add_help)
 
     parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", type=str, help="dataset path")
     parser.add_argument(
@@ -387,11 +405,9 @@ def parse_args():
     # Mixed precision training parameters
     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
 
-    args = parser.parse_args()
-
-    return args
+    return parser
 
 
 if __name__ == "__main__":
-    args = parse_args()
+    args = get_args_parser().parse_args()
     main(args)
diff --git a/references/video_classification/utils.py b/references/video_classification/utils.py
index 024426d5916..934f62f66ae 100644
--- a/references/video_classification/utils.py
+++ b/references/video_classification/utils.py
@@ -253,12 +253,12 @@ def init_distributed_mode(args):
     setup_for_distributed(args.rank == 0)
 
 
-def reduce_across_processes(val):
+def reduce_across_processes(val, op=dist.ReduceOp.SUM):
     if not is_dist_avail_and_initialized():
         # nothing to sync, but we still convert to tensor for consistency with the distributed case.
         return torch.tensor(val)
 
     t = torch.tensor(val, device="cuda")
     dist.barrier()
-    dist.all_reduce(t)
+    dist.all_reduce(t, op=op)
     return t
diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py
index d8bfc0dbb77..0fd76399b5e 100644
--- a/torchvision/models/video/mvit.py
+++ b/torchvision/models/video/mvit.py
@@ -445,12 +445,15 @@ class MViT_V1_B_Weights(WeightsEnum):
             "min_temporal_size": 16,
             "categories": _KINETICS400_CATEGORIES,
             "recipe": "https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md",
-            "_docs": """These weights support 16-frame clip inputs and were ported from the paper.""",
+            "_docs": (
+                "The weights were ported from the paper. The accuracies are estimated on video-level "
+                "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`."
+            ),
             "num_params": 36610672,
             "_metrics": {
                 "Kinetics-400": {
-                    "acc@1": 78.47,
-                    "acc@5": 93.65,
+                    "acc@1": 78.477,
+                    "acc@5": 93.582,
                 }
             },
         },
diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py
index cd40717bbbd..6ec8bfc0b3e 100644
--- a/torchvision/models/video/resnet.py
+++ b/torchvision/models/video/resnet.py
@@ -312,7 +312,10 @@ def _video_resnet(
     "min_size": (1, 1),
     "categories": _KINETICS400_CATEGORIES,
     "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
-    "_docs": """These weights reproduce closely the accuracy of the paper for 16-frame clip inputs.""",
+    "_docs": (
+        "The weights reproduce closely the accuracy of the paper. The accuracies are estimated on video-level "
+        "with parameters `frame_rate=15`, `clips_per_video=5`, and `clip_len=16`."
+    ),
 }
@@ -325,8 +328,8 @@ class R3D_18_Weights(WeightsEnum):
             "num_params": 33371472,
             "_metrics": {
                 "Kinetics-400": {
-                    "acc@1": 52.75,
-                    "acc@5": 75.45,
+                    "acc@1": 63.200,
+                    "acc@5": 83.479,
                 }
             },
         },
@@ -343,8 +346,8 @@ class MC3_18_Weights(WeightsEnum):
             "num_params": 11695440,
             "_metrics": {
                 "Kinetics-400": {
-                    "acc@1": 53.90,
-                    "acc@5": 76.29,
+                    "acc@1": 63.960,
+                    "acc@5": 84.130,
                 }
             },
         },
@@ -361,8 +364,8 @@ class R2Plus1D_18_Weights(WeightsEnum):
             "num_params": 31505325,
             "_metrics": {
                 "Kinetics-400": {
-                    "acc@1": 57.50,
-                    "acc@5": 78.81,
+                    "acc@1": 67.463,
+                    "acc@5": 86.175,
                 }
             },
         },