
transforms.ColorJitter().get_params(...) does not support float inputs #2669

@sagadre

🐛 Bug

The docstring for `transforms.ColorJitter.get_params(...)` states that it accepts float and tuple inputs, just like `ColorJitter.__init__(...)`. However, the current implementation only supports tuple inputs: unlike the constructor, it does not accept floats.

To Reproduce

Steps to reproduce the behavior:

```python
from torchvision import transforms as T

t = T.ColorJitter()
# Fails: get_params(...) does not accept float arguments, only (min, max) tuples.
t = t.get_params(0.4, 0.4, 0.4, 0.2)
```
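The failure is consistent with `get_params(...)` treating each argument as a `(min, max)` pair and indexing into it. The pure-Python stand-in below illustrates that assumption only; it is not torchvision's code, `get_params_tuple_only` is a hypothetical name, and sampling with `random.uniform` is assumed for illustration.

```python
import random


def get_params_tuple_only(brightness):
    """Hypothetical stand-in: samples a factor from a (min, max) pair."""
    # Indexing assumes a sequence; a bare float has no [0] or [1].
    return random.uniform(brightness[0], brightness[1])


# A (min, max) tuple works.
factor = get_params_tuple_only((0.6, 1.4))
print(f"sampled brightness factor: {factor:.3f}")

# A bare float fails with a TypeError.
try:
    get_params_tuple_only(0.4)
except TypeError as exc:
    print(f"float input rejected: {exc}")
```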

Expected behavior

`.get_params(...)` should accept float inputs, as the constructor does and as its docstring states.
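A minimal sketch of the expected behavior, assuming the float semantics documented for the `ColorJitter` constructor: a float `v` maps to `[max(0, 1 - v), 1 + v]` for brightness, contrast and saturation, and to `[-v, v]` for hue, which must stay within `[-0.5, 0.5]`. `jitter_range` and `sample_factors` are hypothetical names, not torchvision API, and the sketch only draws the random factors; it does not apply them to an image.

```python
import random
from numbers import Number
from typing import Optional, Sequence, Tuple, Union

JitterArg = Union[float, Sequence[float], None]


def jitter_range(value: JitterArg, center: float, bound: Tuple[float, float],
                 clip_first_on_zero: bool) -> Optional[Tuple[float, float]]:
    """Hypothetical helper: normalize a float or (min, max) pair into a range."""
    if value is None:
        return None
    if isinstance(value, Number):
        if value < 0:
            raise ValueError("a single jitter value must be non-negative")
        # A float v means [center - v, center + v].
        low, high = center - float(value), center + float(value)
        if clip_first_on_zero:
            low = max(low, 0.0)
    else:
        # Assume a (min, max) pair and pass it through.
        low, high = float(value[0]), float(value[1])
    if not bound[0] <= low <= high <= bound[1]:
        raise ValueError(f"range {(low, high)} is outside {bound}")
    return low, high


def sample_factors(brightness: JitterArg = None, contrast: JitterArg = None,
                   saturation: JitterArg = None, hue: JitterArg = None,
                   rng: Optional[random.Random] = None) -> dict:
    """Hypothetical get_params variant accepting floats or (min, max) pairs."""
    inf = float("inf")
    ranges = {
        "brightness": jitter_range(brightness, 1.0, (0.0, inf), True),
        "contrast": jitter_range(contrast, 1.0, (0.0, inf), True),
        "saturation": jitter_range(saturation, 1.0, (0.0, inf), True),
        "hue": jitter_range(hue, 0.0, (-0.5, 0.5), False),
    }
    if rng is None:
        rng = random.Random()
    # Draw one factor per enabled parameter; disabled (None) ones stay None.
    return {name: (None if r is None else rng.uniform(*r))
            for name, r in ranges.items()}


if __name__ == "__main__":
    print(sample_factors(0.4, 0.4, 0.4, 0.2, rng=random.Random(0)))
```

For example, `sample_factors(0.4, 0.4, 0.4, 0.2)` draws a brightness factor from `[0.6, 1.4]` and a hue factor from `[-0.2, 0.2]`, while `sample_factors(brightness=(0.5, 1.5))` passes the tuple through unchanged and leaves the other factors as `None`.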

Environment

  • PyTorch / torchvision Version (e.g., 1.0 / 0.4.0): 1.4.0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch / torchvision (conda, pip, source): pip
  • Build command you used (if compiling from source): N/A
  • Python version: 3.8
  • CUDA/cuDNN version: 10.1
  • GPU models and configuration: GeForce RTX 2080 Ti
  • Any other relevant information: N/A

cc @vfdev-5
