
Adding align_corners to jax.image.resize #2214

@younesbelkada


Hi there!

I would like to reproduce the operations done by torch.nn.Upsample() with Flax 🎉. In PyTorch, the align_corners flag doesn't really mean "align the corners" but rather "sample with equal spacing": the first and last output pixels land exactly on the first and last input pixels, and the remaining samples are evenly spaced between them. cc @cgarciae! Could this option be added to jax.image.resize? 🙏
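To make the two conventions concrete, here is a small pure-Python sketch of the fractional input index that each output index samples along one axis. It mirrors the coordinate formulas of PyTorch's linear upsampling modes as I understand them; `source_coords` is a hypothetical helper, not part of either library.

```python
def source_coords(in_size: int, out_size: int, align_corners: bool) -> list:
    """Fractional input index sampled by each output index along one axis.

    Hypothetical helper mirroring PyTorch's linear Upsample coordinate mapping.
    """
    coords = []
    for j in range(out_size):
        if align_corners:
            # Pin the first and last pixel centers together and space the
            # samples evenly between input index 0 and input index in_size - 1.
            step = (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
            x = j * step
        else:
            # Half-pixel convention: output pixel center j + 0.5 maps to
            # (j + 0.5) * in_size / out_size in input space, minus 0.5 for the index.
            x = (j + 0.5) * in_size / out_size - 0.5
            # PyTorch clamps negative coordinates to 0 in its linear modes.
            x = max(x, 0.0)
        coords.append(x)
    return coords


print(source_coords(5, 10, align_corners=True))
print(source_coords(5, 10, align_corners=False))
```

For the 5 → 10 column resize in this issue, align_corners=True gives samples 4/9 apart running from 0 to 4, while align_corners=False gives samples 0.5 apart running from -0.25 (clamped to 0) to 4.25, where the upper neighbour index is clamped to the last pixel. The second mapping is the half-pixel convention that jax.image.resize also follows.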

Problem description:

Ideally:

```python
import torch
import torch.nn as nn
import jax.numpy as jnp
import jax
import numpy as np

# 2x5 test image with values in [0, 1]
input_arr = jnp.array([
    [0, 1, 2, 3, 4],
    [5, 6, 7, 8, 9],
]) / 9

# PyTorch expects NCHW, so add batch and channel dimensions: (1, 1, 2, 5)
torch_input_arr = torch.from_numpy(np.array(input_arr)).unsqueeze(0).unsqueeze(0)
upsample_with_align_corners = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
upsample_without_align_corners = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

# jax.image.resize takes the target shape directly: (2, 5) -> (4, 10)
output = jax.image.resize(
    image=input_arr,
    shape=(4, 10),
    method="bilinear")

output_torch_upsample_with_align_corners = upsample_with_align_corners(torch_input_arr)
output_torch_upsample_without_align_corners = upsample_without_align_corners(torch_input_arr)

print("jax:", output)
print("torch upsample with align_corners:", output_torch_upsample_with_align_corners)
print("torch upsample without align_corners:", output_torch_upsample_without_align_corners)
```

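I would like the JAX/Flax output to match output_torch_upsample_with_align_corners. Currently, the default jax.image.resize output only matches PyTorch's align_corners=False result. Here is what I get: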

```
# jax.image.resize default output
DeviceArray([[0.        , 0.02777778, 0.08333334, 0.1388889 , 0.19444445,
              0.25      , 0.30555555, 0.3611111 , 0.4166667 , 0.44444445],
             [0.1388889 , 0.16666667, 0.22222222, 0.2777778 , 0.33333334,
              0.3888889 , 0.44444442, 0.5       , 0.5555556 , 0.5833334 ],
             [0.4166667 , 0.44444448, 0.5       , 0.5555555 , 0.6111111 ,
              0.6666667 , 0.7222222 , 0.7777778 , 0.8333333 , 0.8611111 ],
             [0.5555556 , 0.5833334 , 0.6388889 , 0.6944444 , 0.75      ,
              0.8055556 , 0.8611111 , 0.9166667 , 0.9722222 , 1.        ]], dtype=float32)
# torch, align_corners=True
tensor([[[[0.0000, 0.0494, 0.0988, 0.1481, 0.1975, 0.2469, 0.2963, 0.3457,
           0.3951, 0.4444],
          [0.1852, 0.2346, 0.2840, 0.3333, 0.3827, 0.4321, 0.4815, 0.5309,
           0.5802, 0.6296],
          [0.3704, 0.4198, 0.4691, 0.5185, 0.5679, 0.6173, 0.6667, 0.7160,
           0.7654, 0.8148],
          [0.5556, 0.6049, 0.6543, 0.7037, 0.7531, 0.8025, 0.8519, 0.9012,
           0.9506, 1.0000]]]])
# torch, align_corners=False
tensor([[[[0.0000, 0.0278, 0.0833, 0.1389, 0.1944, 0.2500, 0.3056, 0.3611,
           0.4167, 0.4444],
          [0.1389, 0.1667, 0.2222, 0.2778, 0.3333, 0.3889, 0.4444, 0.5000,
           0.5556, 0.5833],
          [0.4167, 0.4444, 0.5000, 0.5556, 0.6111, 0.6667, 0.7222, 0.7778,
           0.8333, 0.8611],
          [0.5556, 0.5833, 0.6389, 0.6944, 0.7500, 0.8056, 0.8611, 0.9167,
           0.9722, 1.0000]]]])
```
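In the outputs above, the default jax.image.resize result is identical to PyTorch's align_corners=False result, not the align_corners=True one. Until an align_corners option exists, one possible workaround is jax.image.scale_and_translate, which accepts an explicit per-dimension scale and translation. The sketch below picks them so that output index j samples input index j * (in - 1) / (out - 1). `resize_align_corners` is a hypothetical helper; the derivation assumes scale_and_translate's documented half-pixel-center convention, and it has only been reasoned through for method="linear" (PyTorch's bilinear).

```python
import jax
import jax.numpy as jnp
import numpy as np


def resize_align_corners(image, shape, spatial_dims=None, method="linear"):
    """Hypothetical helper: emulate PyTorch's align_corners=True resize.

    Chooses per-dimension scale/translation for jax.image.scale_and_translate
    so that output index j samples input index j * (in - 1) / (out - 1).
    """
    if spatial_dims is None:
        spatial_dims = tuple(range(image.ndim))
    in_size = np.array([image.shape[d] for d in spatial_dims], dtype=np.float32)
    out_size = np.array([shape[d] for d in spatial_dims], dtype=np.float32)
    if np.any(in_size < 2) or np.any(out_size < 2):
        raise ValueError("resized dimensions must have size >= 2 in this sketch")
    # Ratio between output and input sample spacing (the "equal spacing").
    scale = (out_size - 1.0) / (in_size - 1.0)
    # scale_and_translate maps input coordinate x to x * scale + translation,
    # with pixel i centered at i + 0.5; this translation puts the center of
    # input pixel 0 onto the center of output pixel 0 (and last onto last).
    translation = 0.5 * (1.0 - scale)
    return jax.image.scale_and_translate(
        image,
        shape,
        spatial_dims=spatial_dims,
        scale=jnp.asarray(scale),
        translation=jnp.asarray(translation),
        method=method,
        antialias=False,  # match PyTorch, which does not antialias by default
    )


input_arr = jnp.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) / 9
print(resize_align_corners(input_arr, (4, 10)))
```

For batched NHWC inputs, pass the full target shape and restrict the resize to the height and width axes, e.g. `resize_align_corners(x, (n, 2 * h, 2 * w, c), spatial_dims=(1, 2))`.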

Motivation

I am converting Dense Prediction Transformers to Flax and would like the outputs of the PyTorch model and the Flax implementation to match! huggingface/transformers#17779

cc @cgarciae
