Skip to content

Commit b7ed683

Browse files
committed
Try to format code as in #5106
1 parent 94c7dde commit b7ed683

File tree

188 files changed

+553
-740
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

188 files changed

+553
-740
lines changed

references/classification/train_quantization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torchvision
1010
import utils
1111
from torch import nn
12-
from train import train_one_epoch, evaluate, load_data
12+
from train import evaluate, load_data, train_one_epoch
1313

1414

1515
def main(args):

references/detection/group_by_aspect_ratio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import copy
33
import math
44
from collections import defaultdict
5-
from itertools import repeat, chain
5+
from itertools import chain, repeat
66

77
import numpy as np
88
import torch

references/detection/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
import torchvision.models.detection.mask_rcnn
3030
import utils
3131
from coco_utils import get_coco, get_coco_kp
32-
from engine import train_one_epoch, evaluate
33-
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
32+
from engine import evaluate, train_one_epoch
33+
from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
3434
from torchvision.transforms import InterpolationMode
3535
from transforms import SimpleCopyPaste
3636

references/detection/transforms.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from typing import List, Tuple, Dict, Optional, Union
1+
from typing import Dict, List, Optional, Tuple, Union
22

33
import torch
44
import torchvision
55
from torch import nn, Tensor
66
from torchvision import ops
7-
from torchvision.transforms import functional as F
8-
from torchvision.transforms import transforms as T, InterpolationMode
7+
from torchvision.transforms import functional as F, InterpolationMode, transforms as T
98

109

1110
def _flip_coco_person_keypoints(kps, width):

references/optical_flow/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import torch
77
import torchvision.models.optical_flow
88
import utils
9-
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
10-
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
9+
from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain
10+
from torchvision.datasets import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
1111

1212

1313
def get_train_dataset(stage, dataset_root):

references/optical_flow/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import datetime
22
import os
33
import time
4-
from collections import defaultdict
5-
from collections import deque
4+
from collections import defaultdict, deque
65

76
import torch
87
import torch.distributed as dist
@@ -158,7 +157,7 @@ def log_every(self, iterable, print_freq=5, header=None):
158157
def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None):
159158

160159
epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt()
161-
flow_norm = (flow_gt ** 2).sum(dim=1).sqrt()
160+
flow_norm = (flow_gt**2).sum(dim=1).sqrt()
162161

163162
if valid_flow_mask is not None:
164163
epe = epe[valid_flow_mask]
@@ -183,7 +182,7 @@ def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400)
183182
raise ValueError(f"Gamma should be < 1, got {gamma}.")
184183

185184
# exlude invalid pixels and extremely large diplacements
186-
flow_norm = torch.sum(flow_gt ** 2, dim=1).sqrt()
185+
flow_norm = torch.sum(flow_gt**2, dim=1).sqrt()
187186
valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)
188187

189188
valid_flow_mask = valid_flow_mask[:, None, :, :]

references/segmentation/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def update(self, a, b):
7575
with torch.inference_mode():
7676
k = (a >= 0) & (a < n)
7777
inds = n * a[k].to(torch.int64) + b[k]
78-
self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
78+
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
7979

8080
def reset(self):
8181
self.mat.zero_()

test/builtin_dataset_mocks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
import unittest.mock
1515
import warnings
1616
import xml.etree.ElementTree as ET
17-
from collections import defaultdict, Counter
17+
from collections import Counter, defaultdict
1818

1919
import numpy as np
2020
import pytest
2121
import torch
22-
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid
22+
from datasets_utils import combinations_grid, create_image_file, create_image_folder, make_tar, make_zip
2323
from torch.nn.functional import one_hot
2424
from torch.testing import make_tensor as _make_tensor
2525
from torchvision.prototype import datasets

test/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pytest
55
import torch
6-
from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG
6+
from common_utils import CIRCLECI_GPU_NO_CUDA_MSG, CUDA_NOT_AVAILABLE_MSG, IN_CIRCLE_CI, IN_FBCODE, IN_RE_WORKER
77

88

99
def pytest_configure(config):

test/datasets_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import torch
2323
import torchvision.datasets
2424
import torchvision.io
25-
from common_utils import get_tmp_dir, disable_console_output
25+
from common_utils import disable_console_output, get_tmp_dir
2626

2727

2828
__all__ = [

0 commit comments

Comments
 (0)