
MultiScaleRoIAlign creates a type mismatch when using mixed precision training with Apex #1335

@Anjum48


🐛 Bug

To Reproduce

Steps to reproduce the behavior:

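  1. Follow the TorchVision object detection finetuning tutorial: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
  2. Initialize the model and optimizer with Apex AMP: `model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")`
  3. Modify `engine.py` to backpropagate a scaled loss:

```python
from apex import amp

optimizer.zero_grad()
# losses.backward()  # original call, replaced by the scaled-loss version below
with amp.scale_loss(losses, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()  # unchanged from engine.py
```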
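To confirm which inputs reach the pooler in half precision during the reproduction, one option is a forward pre-hook that records input dtypes. This is a minimal sketch, not part of the tutorial: `describe_dtypes` and `attach_dtype_logger` are hypothetical helpers, and the sketch assumes the pooler is reachable as `model.roi_heads.box_roi_pool`, the attribute torchvision's R-CNN models use for their `MultiScaleRoIAlign`.

```python
import torch
from torch import nn


def describe_dtypes(obj):
    """Recursively replace tensors inside dicts/lists/tuples with their dtype."""
    if isinstance(obj, torch.Tensor):
        return obj.dtype
    if isinstance(obj, dict):
        return {k: describe_dtypes(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [describe_dtypes(v) for v in obj]
    # Non-tensor values (e.g. image sizes) are summarized by type name.
    return type(obj).__name__


def attach_dtype_logger(module: nn.Module, log: list):
    """Append the dtypes of the positional inputs each time `module` runs."""
    def hook(mod, args):
        log.append(describe_dtypes(args))
        # Returning None leaves the inputs unchanged.
    return module.register_forward_pre_hook(hook)
```

For example, call `handle = attach_dtype_logger(model.roi_heads.box_roi_pool, log)` before a training step and `handle.remove()` afterwards. A pre-hook registered this way only sees positional arguments, which is how the pooler is invoked inside `RoIHeads`.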

Expected behavior

There is a dtype mismatch at https://github.com/pytorch/vision/blob/master/torchvision/ops/poolers.py#L161: `per_level_feature` is float16, while `rois_per_level` is float32.
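One plausible reason the two dtypes diverge: the RoI tensor is assembled from the box proposals, so it inherits their dtype, while the FPN feature maps come out of convolutions that O1 runs in float16. The sketch below assumes torchvision's documented `[K, 5]` RoI layout (batch index followed by `x1, y1, x2, y2`) and uses a hypothetical `boxes_to_rois` helper to show that the RoI dtype follows the boxes rather than the features.

```python
import torch


def boxes_to_rois(boxes):
    """Hypothetical helper: turn per-image [L_i, 4] boxes into one [K, 5] tensor.

    Column 0 holds the batch index and columns 1-4 hold (x1, y1, x2, y2),
    the single-tensor layout accepted by torchvision.ops.roi_align.
    """
    concat = torch.cat(boxes, dim=0)
    ids = torch.cat(
        [torch.full((b.shape[0], 1), i, dtype=concat.dtype, device=concat.device)
         for i, b in enumerate(boxes)],
        dim=0,
    )
    # The result inherits the dtype of the boxes, not of any feature map.
    return torch.cat([ids, concat], dim=1)


features = torch.zeros(2, 8, 16, 16, dtype=torch.float16)  # half-precision FPN level
rois = boxes_to_rois([torch.tensor([[0.0, 0.0, 4.0, 4.0]]),
                      torch.tensor([[1.0, 1.0, 5.0, 5.0], [2.0, 2.0, 6.0, 6.0]])])
print(features.dtype, rois.dtype)
```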

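I worked around it by casting the RoIs to the feature dtype inside the per-level loop in `poolers.py`:

```python
# Cast the RoIs to the feature map's dtype (float16 under O1) before pooling.
if per_level_feature.dtype != rois_per_level.dtype:
    rois_per_level = rois_per_level.type(per_level_feature.dtype)
```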
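The same workaround can be written without the explicit check. This is a sketch assuming the RoIs should also live on the feature map's device; `match_roi_dtype` is a hypothetical name. `Tensor.to` returns the original tensor when neither dtype nor device needs to change, so the already-matching case costs nothing.

```python
import torch


def match_roi_dtype(per_level_feature: torch.Tensor,
                    rois_per_level: torch.Tensor) -> torch.Tensor:
    # Tensor.to() returns the tensor itself when dtype and device already
    # match, so no `if` check is needed before casting.
    return rois_per_level.to(dtype=per_level_feature.dtype,
                             device=per_level_feature.device)
```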

However, I suspect there is a more robust way of ensuring that the RoI dtypes are consistent.
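One possible direction, sketched here under the assumption that the pooling op accepts float32 inputs, is to promote both tensors to their common dtype instead of downcasting the RoIs. In float16 only 11 significant bits are stored, so coordinates between 1024 and 2048 pixels can only be represented in steps of 1.0, and casting boxes down quantizes them. `promote_for_pooling` is a hypothetical helper.

```python
import torch


def promote_for_pooling(per_level_feature: torch.Tensor,
                        rois_per_level: torch.Tensor):
    """Cast both tensors to their common (wider) dtype instead of downcasting.

    Keeping the RoIs in float32 avoids quantizing box coordinates to the
    coarse float16 grid.
    """
    common = torch.promote_types(per_level_feature.dtype, rois_per_level.dtype)
    return per_level_feature.to(common), rois_per_level.to(common)
```

The trade-off is that the feature map is copied to float32 for this op, which costs memory; the pooled output can be cast back with `.to(per_level_feature.dtype)` if later layers expect float16.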
