33import time
44import warnings
55
6+ import datasets
67import presets
78import torch
89import torch .utils .data
1112import utils
1213from torch import nn
1314from torch .utils .data .dataloader import default_collate
14- from torchvision .datasets .samplers import DistributedSampler , UniformClipSampler , RandomClipSampler
15+ from torchvision .datasets .samplers import DistributedSampler , RandomClipSampler , UniformClipSampler
1516
1617
1718def train_one_epoch (model , criterion , optimizer , lr_scheduler , data_loader , device , epoch , print_freq , scaler = None ):
@@ -21,7 +22,7 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
2122 metric_logger .add_meter ("clips/s" , utils .SmoothedValue (window_size = 10 , fmt = "{value:.3f}" ))
2223
2324 header = f"Epoch: [{ epoch } ]"
24- for video , target in metric_logger .log_every (data_loader , print_freq , header ):
25+ for video , target , _ in metric_logger .log_every (data_loader , print_freq , header ):
2526 start_time = time .time ()
2627 video , target = video .to (device ), target .to (device )
2728 with torch .cuda .amp .autocast (enabled = scaler is not None ):
@@ -52,13 +53,25 @@ def evaluate(model, criterion, data_loader, device):
5253 metric_logger = utils .MetricLogger (delimiter = " " )
5354 header = "Test:"
5455 num_processed_samples = 0
56+ # Group and aggregate output of a video
57+ num_videos = len (data_loader .dataset .samples )
58+ num_classes = len (data_loader .dataset .classes )
59+ agg_preds = torch .zeros ((num_videos , num_classes ), dtype = torch .float32 , device = device )
60+ agg_targets = torch .zeros ((num_videos ), dtype = torch .int32 , device = device )
5561 with torch .inference_mode ():
56- for video , target in metric_logger .log_every (data_loader , 100 , header ):
62+ for video , target , video_idx in metric_logger .log_every (data_loader , 100 , header ):
5763 video = video .to (device , non_blocking = True )
5864 target = target .to (device , non_blocking = True )
5965 output = model (video )
6066 loss = criterion (output , target )
6167
68+ # Use softmax to convert output into prediction probability
69+ preds = torch .softmax (output , dim = 1 )
70+ for b in range (video .size (0 )):
71+ idx = video_idx [b ].item ()
72+ agg_preds [idx ] += preds [b ].detach ()
73+ agg_targets [idx ] = target [b ].detach ().item ()
74+
6275 acc1 , acc5 = utils .accuracy (output , target , topk = (1 , 5 ))
6376 # FIXME need to take into account that the datasets
6477 # could have been padded in distributed setup
@@ -95,6 +108,11 @@ def evaluate(model, criterion, data_loader, device):
95108 top1 = metric_logger .acc1 , top5 = metric_logger .acc5
96109 )
97110 )
111+ # Reduce the agg_preds and agg_targets from all gpu and show result
112+ agg_preds = utils .reduce_across_processes (agg_preds )
113+ agg_targets = utils .reduce_across_processes (agg_targets , op = torch .distributed .ReduceOp .MAX )
114+ agg_acc1 , agg_acc5 = utils .accuracy (agg_preds , agg_targets , topk = (1 , 5 ))
115+ print (" * Video Acc@1 {acc1:.3f} Video Acc@5 {acc5:.3f}" .format (acc1 = agg_acc1 , acc5 = agg_acc5 ))
98116 return metric_logger .acc1 .global_avg
99117
100118
@@ -110,7 +128,7 @@ def _get_cache_path(filepath, args):
110128
def collate_fn(batch):
    """Collate (video, target, video_idx) triples into batched tensors.

    Each dataset sample is a tuple whose element 1 is the audio track;
    audio is dropped here because the model consumes video frames only.
    Element 3 is the video index, kept so evaluation can aggregate
    clip-level predictions back to whole videos.
    """
    # keep (video, label, video_idx); drop the audio at position 1
    batch = [(d[0], d[2], d[3]) for d in batch]
    return default_collate(batch)
115133
116134
@@ -146,7 +164,7 @@ def main(args):
146164 else :
147165 if args .distributed :
148166 print ("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster" )
149- dataset = torchvision . datasets .Kinetics (
167+ dataset = datasets .KineticsWithVideoId (
150168 args .data_path ,
151169 frames_per_clip = args .clip_len ,
152170 num_classes = args .kinetics_version ,
@@ -183,7 +201,7 @@ def main(args):
183201 else :
184202 if args .distributed :
185203 print ("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster" )
186- dataset_test = torchvision . datasets .Kinetics (
204+ dataset_test = datasets .KineticsWithVideoId (
187205 args .data_path ,
188206 frames_per_clip = args .clip_len ,
189207 num_classes = args .kinetics_version ,
@@ -313,10 +331,10 @@ def main(args):
313331 print (f"Training time { total_time_str } " )
314332
315333
316- def parse_args ( ):
334+ def get_args_parser ( add_help = True ):
317335 import argparse
318336
319- parser = argparse .ArgumentParser (description = "PyTorch Video Classification Training" )
337+ parser = argparse .ArgumentParser (description = "PyTorch Video Classification Training" , add_help = add_help )
320338
321339 parser .add_argument ("--data-path" , default = "/datasets01_101/kinetics/070618/" , type = str , help = "dataset path" )
322340 parser .add_argument (
@@ -387,11 +405,9 @@ def parse_args():
387405 # Mixed precision training parameters
388406 parser .add_argument ("--amp" , action = "store_true" , help = "Use torch.cuda.amp for mixed precision training" )
389407
390- args = parser .parse_args ()
391-
392- return args
408+ return parser
393409
394410
if __name__ == "__main__":
    # Build the parser lazily so other tools can import get_args_parser()
    # without triggering argument parsing.
    args = get_args_parser().parse_args()
    main(args)