15 changes: 15 additions & 0 deletions references/video_classification/datasets.py
@@ -0,0 +1,15 @@
from typing import Tuple

import torchvision
from torch import Tensor


class KineticsWithVideoId(torchvision.datasets.Kinetics):
    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int, int]:
        video, audio, info, video_idx = self.video_clips.get_clip(idx)
        label = self.samples[video_idx][1]

        if self.transform is not None:
            video = self.transform(video)

        return video, audio, label, video_idx
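A minimal usage sketch of the new dataset class, assuming a local Kinetics-400 copy under `./kinetics` (the path and arguments are illustrative):

```python
# Illustrative only: assumes a Kinetics-400 layout under ./kinetics.
import datasets

dataset = datasets.KineticsWithVideoId(
    "./kinetics", frames_per_clip=16, num_classes="400", split="val"
)
video, audio, label, video_idx = dataset[0]
# video_idx identifies the clip's source video, which lets the
# evaluation loop aggregate per-clip predictions into one per video.
```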
40 changes: 28 additions & 12 deletions references/video_classification/train.py
@@ -3,6 +3,7 @@
import time
import warnings

+import datasets
import presets
import torch
import torch.utils.data
@@ -11,7 +12,7 @@
import utils
from torch import nn
from torch.utils.data.dataloader import default_collate
-from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
+from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler


def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
@@ -21,7 +22,7 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
    metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}"))

    header = f"Epoch: [{epoch}]"
-    for video, target in metric_logger.log_every(data_loader, print_freq, header):
+    for video, target, _ in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        video, target = video.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
@@ -52,13 +53,25 @@ def evaluate(model, criterion, data_loader, device):
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"
    num_processed_samples = 0
+    # Group and aggregate the outputs of each video's clips
+    num_videos = len(data_loader.dataset.samples)
+    num_classes = len(data_loader.dataset.classes)
+    agg_preds = torch.zeros((num_videos, num_classes), dtype=torch.float32, device=device)
+    agg_targets = torch.zeros((num_videos), dtype=torch.int32, device=device)
    with torch.inference_mode():
-        for video, target in metric_logger.log_every(data_loader, 100, header):
+        for video, target, video_idx in metric_logger.log_every(data_loader, 100, header):
            video = video.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(video)
            loss = criterion(output, target)

+            # Use softmax to convert outputs into prediction probabilities
+            preds = torch.softmax(output, dim=1)
+            for b in range(video.size(0)):
+                idx = video_idx[b].item()
+                agg_preds[idx] += preds[b].detach()
+                agg_targets[idx] = target[b].detach().item()
+
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
@@ -95,6 +108,11 @@ def evaluate(model, criterion, data_loader, device):
            top1=metric_logger.acc1, top5=metric_logger.acc5
        )
    )
+    # Reduce agg_preds and agg_targets across all GPUs and report the result
+    agg_preds = utils.reduce_across_processes(agg_preds)
+    agg_targets = utils.reduce_across_processes(agg_targets, op=torch.distributed.ReduceOp.MAX)
+    agg_acc1, agg_acc5 = utils.accuracy(agg_preds, agg_targets, topk=(1, 5))
+    print(" * Video Acc@1 {acc1:.3f} Video Acc@5 {acc5:.3f}".format(acc1=agg_acc1, acc5=agg_acc5))
    return metric_logger.acc1.global_avg


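This aggregation is the heart of the change: clip-level softmax scores are summed per source video, and accuracy is then computed on the per-video sums. A self-contained single-process sketch with dummy tensors (all names and values illustrative):

```python
import torch

num_videos, num_classes = 3, 5
clip_logits = torch.randn(8, num_classes)           # a batch of 8 clips
video_idx = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])  # source video of each clip
target = torch.tensor([4, 4, 4, 1, 1, 0, 0, 0])     # per-clip labels (equal within a video)

agg_preds = torch.zeros(num_videos, num_classes)
agg_targets = torch.zeros(num_videos, dtype=torch.int32)
preds = torch.softmax(clip_logits, dim=1)
for b in range(clip_logits.size(0)):
    idx = video_idx[b].item()
    agg_preds[idx] += preds[b]
    agg_targets[idx] = target[b].item()

video_top1 = agg_preds.argmax(dim=1)                # video-level prediction
acc1 = (video_top1 == agg_targets).float().mean()   # video-level Acc@1
```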
Expand All @@ -110,7 +128,7 @@ def _get_cache_path(filepath, args):

def collate_fn(batch):
    # remove audio from the batch
-    batch = [(d[0], d[2]) for d in batch]
+    batch = [(d[0], d[2], d[3]) for d in batch]
    return default_collate(batch)


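For reference, each dataset item is now a `(video, audio, label, video_idx)` tuple, so the collate keeps indices 0, 2, and 3 before stacking; a quick sketch with dummy items:

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two dummy (video, audio, label, video_idx) items from the same video.
items = [
    (torch.zeros(3, 16, 112, 112), torch.zeros(1), 7, 0),
    (torch.zeros(3, 16, 112, 112), torch.zeros(1), 7, 0),
]
video, label, video_idx = default_collate([(d[0], d[2], d[3]) for d in items])
print(video.shape, label, video_idx)
# torch.Size([2, 3, 16, 112, 112]) tensor([7, 7]) tensor([0, 0])
```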
@@ -146,7 +164,7 @@ def main(args):
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
-        dataset = torchvision.datasets.Kinetics(
+        dataset = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
@@ -183,7 +201,7 @@ def main(args):
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
-        dataset_test = torchvision.datasets.Kinetics(
+        dataset_test = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
@@ -313,10 +331,10 @@ def main(args):
print(f"Training time {total_time_str}")


-def parse_args():
+def get_args_parser(add_help=True):
    import argparse

-    parser = argparse.ArgumentParser(description="PyTorch Video Classification Training")
+    parser = argparse.ArgumentParser(description="PyTorch Video Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", type=str, help="dataset path")
    parser.add_argument(
@@ -387,11 +405,9 @@ def parse_args():
    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

-    args = parser.parse_args()
-
-    return args
+    return parser


if __name__ == "__main__":
-    args = parse_args()
+    args = get_args_parser().parse_args()
    main(args)
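Returning the parser instead of parsed arguments lets other scripts and tooling extend or introspect it; a sketch of such reuse (the extra flag is purely hypothetical):

```python
# Illustrative reuse from another script; importing train pulls in its dependencies.
from train import get_args_parser

parser = get_args_parser(add_help=False)
parser.add_argument("--dry-run", action="store_true")  # hypothetical extra flag
args = parser.parse_args(["--data-path", "/tmp/kinetics", "--dry-run"])
print(args.data_path, args.dry_run)
```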
4 changes: 2 additions & 2 deletions references/video_classification/utils.py
@@ -253,12 +253,12 @@ def init_distributed_mode(args):
    setup_for_distributed(args.rank == 0)


-def reduce_across_processes(val):
+def reduce_across_processes(val, op=dist.ReduceOp.SUM):
    if not is_dist_avail_and_initialized():
        # nothing to sync, but we still convert to tensor for consistency with the distributed case.
        return torch.tensor(val)

    t = torch.tensor(val, device="cuda")
    dist.barrier()
-    dist.all_reduce(t)
+    dist.all_reduce(t, op=op)
    return t
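Presumably `ReduceOp.MAX` is used for the targets because each rank only fills labels for the videos it processed, leaving zeros elsewhere: summing would scale up labels seen on several ranks, while an element-wise max recovers them. A single-process illustration with dummy values:

```python
import torch

# Rank 0 saw videos 0 and 2; rank 1 saw videos 1 and 2 (dummy aggregates).
preds_r0 = torch.tensor([[0.9, 0.1], [0.0, 0.0], [0.4, 0.6]])
preds_r1 = torch.tensor([[0.0, 0.0], [0.2, 0.8], [0.5, 0.5]])
agg_preds = preds_r0 + preds_r1  # ReduceOp.SUM: probability mass accumulates

targets_r0 = torch.tensor([1, 0, 0], dtype=torch.int32)  # video 1 unseen -> 0
targets_r1 = torch.tensor([0, 1, 0], dtype=torch.int32)  # video 0 unseen -> 0
agg_targets = torch.maximum(targets_r0, targets_r1)      # ReduceOp.MAX -> tensor([1, 1, 0])
```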
9 changes: 6 additions & 3 deletions torchvision/models/video/mvit.py
@@ -445,12 +445,15 @@ class MViT_V1_B_Weights(WeightsEnum):
"min_temporal_size": 16,
"categories": _KINETICS400_CATEGORIES,
"recipe": "https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md",
"_docs": """These weights support 16-frame clip inputs and were ported from the paper.""",
"_docs": (
"The weights were ported from the paper. The accuracies are estimated on video-level "
"with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
),
"num_params": 36610672,
"_metrics": {
"Kinetics-400": {
"acc@1": 78.47,
"acc@5": 93.65,
"acc@1": 78.477,
"acc@5": 93.582,
}
},
},
17 changes: 10 additions & 7 deletions torchvision/models/video/resnet.py
@@ -312,7 +312,10 @@ def _video_resnet(
"min_size": (1, 1),
"categories": _KINETICS400_CATEGORIES,
"recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
"_docs": """These weights reproduce closely the accuracy of the paper for 16-frame clip inputs.""",
"_docs": (
"The weights reproduce closely the accuracy of the paper. The accuracies are estimated on video-level "
"with parameters `frame_rate=15`, `clips_per_video=5`, and `clip_len=16`."
),
}


@@ -325,8 +328,8 @@ class R3D_18_Weights(WeightsEnum):
"num_params": 33371472,
"_metrics": {
"Kinetics-400": {
"acc@1": 52.75,
"acc@5": 75.45,
"acc@1": 63.200,
"acc@5": 83.479,
}
},
},
@@ -343,8 +346,8 @@ class MC3_18_Weights(WeightsEnum):
"num_params": 11695440,
"_metrics": {
"Kinetics-400": {
"acc@1": 53.90,
"acc@5": 76.29,
"acc@1": 63.960,
"acc@5": 84.130,
}
},
},
@@ -361,8 +364,8 @@ class R2Plus1D_18_Weights(WeightsEnum):
"num_params": 31505325,
"_metrics": {
"Kinetics-400": {
"acc@1": 57.50,
"acc@5": 78.81,
"acc@1": 67.463,
"acc@5": 86.175,
}
},
},
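A quick way to read the updated metadata off the weight enums, assuming a torchvision build that includes these weights:

```python
import torchvision

weights = torchvision.models.video.R3D_18_Weights.KINETICS400_V1
print(weights.meta["_metrics"]["Kinetics-400"]["acc@1"])  # 63.2 (video-level)
```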