Closed
Description
Hi there!
I would like to reproduce the operations performed by torch.nn.Upsample() with Flax 🎉. In PyTorch, it seems that the align_corners flag doesn't literally mean "align the corners" but rather "sample with equal spacing" (cc @cgarciae)! Could this option be added to jax.image.resize? 🙏
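To make the difference concrete, here is a minimal sketch of how each setting maps an output pixel index back to a fractional input coordinate along one axis. The helper name `source_coords` is hypothetical. The formulas are our reading of PyTorch's bilinear behavior and should be treated as an assumption: with align_corners=True the first and last output pixels land exactly on the first and last input pixels and the samples in between are equally spaced; with align_corners=False pixel centers are aligned (half-pixel convention) and negative coordinates are clamped to 0.

```python
import jax.numpy as jnp


def source_coords(out_size: int, in_size: int, align_corners: bool) -> jnp.ndarray:
    """Fractional input coordinate sampled by each output index along one axis.

    Hypothetical helper; mirrors (as an assumption) PyTorch's bilinear mapping.
    """
    i = jnp.arange(out_size, dtype=jnp.float32)
    if align_corners:
        # Corner pixels coincide; samples are equally spaced from 0 to in_size - 1.
        if out_size == 1:
            return jnp.zeros((1,), dtype=jnp.float32)
        return i * (in_size - 1) / (out_size - 1)
    # Half-pixel convention: align pixel centers, then clamp negatives to 0.
    return jnp.maximum((i + 0.5) * in_size / out_size - 0.5, 0.0)


print(source_coords(10, 5, align_corners=True))
print(source_coords(10, 5, align_corners=False))
```

For the 5-pixel rows of the 2x5 example in this issue (upsampled to 10 pixels), align_corners=True samples at 0, 4/9, 8/9, ..., 4, whereas align_corners=False samples at 0 (clamped from -0.25), 0.25, 0.75, ..., 4.25, where the upper neighbor of the last sample is clamped to the last input pixel.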
Problem description:
Ideally:
```python
import torch
import torch.nn as nn
import jax.numpy as jnp
import jax
import numpy as np

# 2x5 test image with values in [0, 1]
input_arr = jnp.array([
    [0, 1, 2, 3, 4],
    [5, 6, 7, 8, 9],
]) / 9

# PyTorch expects (N, C, H, W), so add batch and channel axes
torch_input_arr = torch.from_numpy(np.array(input_arr)).unsqueeze(0).unsqueeze(0)

upsample_with_align_corners = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
upsample_without_align_corners = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

# jax.image.resize has no align_corners option
output = jax.image.resize(
    image=input_arr,
    shape=(4, 10),
    method="bilinear")

output_torch_upsample_with_align_corners = upsample_with_align_corners(torch_input_arr)
output_torch_upsample_without_align_corners = upsample_without_align_corners(torch_input_arr)

print("jax: ", output)
print("torch upsample align corner:", output_torch_upsample_with_align_corners)
print("torch upsample without align corners", output_torch_upsample_without_align_corners)
```
I would like Flax's output (from jax.image.resize) to match output_torch_upsample_with_align_corners. Here is what I get:
```
# jax default output
DeviceArray([[0. , 0.02777778, 0.08333334, 0.1388889 , 0.19444445,
0.25 , 0.30555555, 0.3611111 , 0.4166667 , 0.44444445],
[0.1388889 , 0.16666667, 0.22222222, 0.2777778 , 0.33333334,
0.3888889 , 0.44444442, 0.5 , 0.5555556 , 0.5833334 ],
[0.4166667 , 0.44444448, 0.5 , 0.5555555 , 0.6111111 ,
0.6666667 , 0.7222222 , 0.7777778 , 0.8333333 , 0.8611111 ],
[0.5555556 , 0.5833334 , 0.6388889 , 0.6944444 , 0.75 ,
0.8055556 , 0.8611111 , 0.9166667 , 0.9722222 , 1. ]], dtype=float32)
# torch with align_corners=True
tensor([[[[0.0000, 0.0494, 0.0988, 0.1481, 0.1975, 0.2469, 0.2963, 0.3457,
0.3951, 0.4444],
[0.1852, 0.2346, 0.2840, 0.3333, 0.3827, 0.4321, 0.4815, 0.5309,
0.5802, 0.6296],
[0.3704, 0.4198, 0.4691, 0.5185, 0.5679, 0.6173, 0.6667, 0.7160,
0.7654, 0.8148],
[0.5556, 0.6049, 0.6543, 0.7037, 0.7531, 0.8025, 0.8519, 0.9012,
0.9506, 1.0000]]]])
# torch with align_corners=False
tensor([[[[0.0000, 0.0278, 0.0833, 0.1389, 0.1944, 0.2500, 0.3056, 0.3611,
0.4167, 0.4444],
[0.1389, 0.1667, 0.2222, 0.2778, 0.3333, 0.3889, 0.4444, 0.5000,
0.5556, 0.5833],
[0.4167, 0.4444, 0.5000, 0.5556, 0.6111, 0.6667, 0.7222, 0.7778,
0.8333, 0.8611],
[0.5556, 0.5833, 0.6389, 0.6944, 0.7500, 0.8056, 0.8611, 0.9167,
0.9722, 1.0000]]]])
```
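One possible workaround in plain JAX, without changes to jax.image.resize, is to build the 1-D linear interpolation weights explicitly and apply them separably along each axis. This is a sketch under stated assumptions: it handles a single 2-D (H, W) image like input_arr (batch and channel axes would need extra einsum indices), `interp_matrix` and `bilinear_resize` are hypothetical helper names rather than JAX or Flax APIs, and the coordinate mapping is our reading of PyTorch's behavior. With align_corners=False it uses the half-pixel mapping, which for this upsampling case should coincide with jax.image.resize, consistent with the jax output above matching torch without align corners.

```python
import jax
import jax.numpy as jnp


def interp_matrix(out_size: int, in_size: int, align_corners: bool) -> jnp.ndarray:
    """(out_size, in_size) matrix of 1-D linear interpolation weights (hypothetical helper)."""
    i = jnp.arange(out_size, dtype=jnp.float32)
    if align_corners:
        # Equally spaced samples from the first to the last input pixel.
        src = i * (in_size - 1) / (out_size - 1) if out_size > 1 else jnp.zeros_like(i)
    else:
        # Half-pixel centers, negative coordinates clamped to 0.
        src = jnp.maximum((i + 0.5) * in_size / out_size - 0.5, 0.0)
    lo = jnp.clip(jnp.floor(src).astype(jnp.int32), 0, in_size - 1)
    hi = jnp.minimum(lo + 1, in_size - 1)  # clamp upper neighbor at the border
    frac = (src - lo.astype(jnp.float32))[:, None]
    return (1.0 - frac) * jax.nn.one_hot(lo, in_size) + frac * jax.nn.one_hot(hi, in_size)


def bilinear_resize(x: jnp.ndarray, out_shape, align_corners: bool) -> jnp.ndarray:
    """Bilinearly resize a 2-D (H, W) array by interpolating rows, then columns."""
    wh = interp_matrix(out_shape[0], x.shape[0], align_corners)
    ww = interp_matrix(out_shape[1], x.shape[1], align_corners)
    # out[o, p] = sum_{h, w} wh[o, h] * x[h, w] * ww[p, w]
    return jnp.einsum("oh,hw,pw->op", wh, x, ww, precision=jax.lax.Precision.HIGHEST)


input_arr = jnp.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) / 9
print(bilinear_resize(input_arr, (4, 10), align_corners=True))
# Largest deviation from jax.image.resize when align_corners=False
ref = jax.image.resize(input_arr, (4, 10), method="bilinear")
print(jnp.abs(bilinear_resize(input_arr, (4, 10), align_corners=False) - ref).max())
```

With align_corners=True, entry [0, 1] works out to (4/9)/9 = 4/81 ≈ 0.0494 and entry [1, 0] to (1/3)(5/9) = 5/27 ≈ 0.1852, matching the torch values above. Dense (out, in) weight matrices keep the code short and differentiable but cost O(out·in) memory per axis; for large images, jax.scipy.ndimage.map_coordinates with order=1 and coordinates from the same mapping may be a lighter alternative.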
Motivation
I am converting Dense Prediction Transformers to Flax and would like the outputs of the PyTorch model and the Flax implementation to match! huggingface/transformers#17779
cc @cgarciae