Skip to content

Commit ac61421

Browse files
committed
Merge branch 'master' of github.com:pytorch/vision into restore_cache
2 parents 8e8c64a + 61e00d5 commit ac61421

File tree

4 files changed

+22
-6
lines changed

4 files changed

+22
-6
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ clear and has sufficient instructions to be able to reproduce the issue.
3333
### Install PyTorch Nightly
3434

3535
```bash
36-
conda install pytorch -c pytorch-nightly
36+
conda install pytorch -c pytorch-nightly -c conda-forge
3737
# or with pip (see https://pytorch.org/get-started/locally/)
3838
# pip install numpy
3939
# pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html

test/test_transforms.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,20 @@ def test_to_tensor(self):
620620
output = trans(img)
621621
self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))
622622

623+
def test_to_tensor_with_other_default_dtypes(self):
624+
current_def_dtype = torch.get_default_dtype()
625+
626+
t = transforms.ToTensor()
627+
np_arr = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
628+
img = Image.fromarray(np_arr)
629+
630+
for dtype in [torch.float16, torch.float, torch.double]:
631+
torch.set_default_dtype(dtype)
632+
res = t(img)
633+
self.assertTrue(res.dtype == dtype, msg=f"{res.dtype} vs {dtype}")
634+
635+
torch.set_default_dtype(current_def_dtype)
636+
623637
def test_max_value(self):
624638
for dtype in int_dtypes():
625639
self.assertEqual(F_t._max_value(dtype), torch.iinfo(dtype).max)

torchvision/models/detection/backbone_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def resnet_fpn_backbone(
9696
# select layers that wont be frozen
9797
assert 0 <= trainable_layers <= 5
9898
layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
99-
# freeze layers only if pretrained backbone is used
99+
if trainable_layers == 5:
100+
layers_to_train.append('bn1')
100101
for name, parameter in backbone.named_parameters():
101102
if all([not name.startswith(layer) for layer in layers_to_train]):
102103
parameter.requires_grad_(False)
@@ -152,7 +153,6 @@ def mobilenet_backbone(
152153
assert 0 <= trainable_layers <= num_stages
153154
freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
154155

155-
# freeze layers only if pretrained backbone is used
156156
for b in backbone[:freeze_before]:
157157
for parameter in b.parameters():
158158
parameter.requires_grad_(False)

torchvision/transforms/functional.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def to_tensor(pic):
104104
if _is_numpy(pic) and not _is_numpy_image(pic):
105105
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
106106

107+
default_float_dtype = torch.get_default_dtype()
108+
107109
if isinstance(pic, np.ndarray):
108110
# handle numpy array
109111
if pic.ndim == 2:
@@ -112,12 +114,12 @@ def to_tensor(pic):
112114
img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
113115
# backward compatibility
114116
if isinstance(img, torch.ByteTensor):
115-
return img.float().div(255)
117+
return img.to(dtype=default_float_dtype).div(255)
116118
else:
117119
return img
118120

119121
if accimage is not None and isinstance(pic, accimage.Image):
120-
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
122+
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=default_float_dtype)
121123
pic.copyto(nppic)
122124
return torch.from_numpy(nppic)
123125

@@ -137,7 +139,7 @@ def to_tensor(pic):
137139
# put it from HWC to CHW format
138140
img = img.permute((2, 0, 1)).contiguous()
139141
if isinstance(img, torch.ByteTensor):
140-
return img.float().div(255)
142+
return img.to(dtype=default_float_dtype).div(255)
141143
else:
142144
return img
143145

0 commit comments

Comments
 (0)