torch.nn.InstanceNorm2d fails at runtime  #4669

@wojtke

🐛 Describe the bug

torch.nn.InstanceNorm2d, which by default does not track running mean and variance, fails at runtime. It goes through the lowering process without any errors.

The error comes from an assert in the portable kernel that requires aten__native_batch_norm_legit_no_stats to run with training == false. The exported graph passes training=True to this op even though the model is in eval() mode. The same happens with torch.nn.BatchNorm2d with track_running_stats=False.
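For reference, when there are no running statistics the layer has to normalize with statistics taken from the input itself, in eval mode as well. Below is a minimal pure-Python sketch of that computation, assuming affine=False (the InstanceNorm2d default) and biased variance. `instance_norm_2d` is a hypothetical helper written for illustration, not PyTorch or ExecuTorch code:

```python
import math


def instance_norm_2d(x, eps=1e-5):
    """Normalize each (sample, channel) plane of a nested list shaped [N][C][H][W]."""
    out = []
    for sample in x:
        normed_sample = []
        for plane in sample:
            values = [v for row in plane for v in row]
            mean = sum(values) / len(values)
            # Biased variance: divide by the element count, not count - 1.
            var = sum((v - mean) ** 2 for v in values) / len(values)
            scale = 1.0 / math.sqrt(var + eps)
            normed_sample.append([[(v - mean) * scale for v in row] for row in plane])
        out.append(normed_sample)
    return out


# One sample, one channel, a 2x2 plane.
x = [[[[1.0, 2.0], [3.0, 4.0]]]]
y = instance_norm_2d(x)
```

The mean and variance here depend only on the current input, so no stored statistics are involved.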

Minimal case to reproduce:

import torch
from executorch.exir import to_edge


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # InstanceNorm2d defaults to track_running_stats=False.
        self.gn = torch.nn.InstanceNorm2d(3)

    def forward(self, x: torch.Tensor):
        return self.gn(x)


model = Model().eval()
sample_inputs = (torch.rand(1, 3, 3, 3),)

# Export to ATen dialect, lower to Edge dialect, then to an ExecuTorch program.
aten_dialect = torch.export.export(model, sample_inputs)
edge_program = to_edge(aten_dialect)
print(edge_program.exported_program())
executorch_program = edge_program.to_executorch()

with open("instance_norm.pte", "wb") as file:
    executorch_program.write_to_file(file)

The exported program at the Edge dialect step:

ExportedProgram:
   class GraphModule(torch.nn.Module):
       def forward(self, x: "f32[1, 3, 3, 3]"):
            # File: /Users/woj/executorch_main/executorch/test.py:11 in forward, code: return self.gn(x)
           aten_view_copy_default: "f32[1, 3, 3, 3]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(x, [1, 3, 3, 3]);  x = None
           aten__native_batch_norm_legit_no_stats = executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_stats(aten_view_copy_default, None, None, True, 0.1, 1e-05);  aten_view_copy_default = None
           getitem: "f32[1, 3, 3, 3]" = aten__native_batch_norm_legit_no_stats[0];  aten__native_batch_norm_legit_no_stats = None
           aten_view_copy_default_1: "f32[1, 3, 3, 3]" = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem, [1, 3, 3, 3]);  getitem = None
           return (aten_view_copy_default_1,)
           
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_view_copy_default_1'), target=None)])
Range constraints: {}
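The `True` in the fourth position of the aten__native_batch_norm_legit_no_stats call above is the training flag that the portable kernel later rejects. The following is a rough sketch of how such nodes could be flagged before running the model. `find_training_no_stats_nodes` is a hypothetical helper, and it assumes that the string form of the node target contains both `_native_batch_norm_legit` and `no_stats`, and that the training flag sits at positional index 3 as in the printed graph. The nodes below are stand-ins so the snippet runs without ExecuTorch installed:

```python
from types import SimpleNamespace


def find_training_no_stats_nodes(graph, training_arg_index=3):
    """Return call_function nodes for the no_stats batch-norm op that pass training=True."""
    hits = []
    for node in graph.nodes:
        if node.op != "call_function":
            continue
        target = str(node.target)
        # Assumption: the target's string form names the op and its overload.
        if "_native_batch_norm_legit" not in target or "no_stats" not in target:
            continue
        args = node.args
        if len(args) > training_arg_index and args[training_arg_index] is True:
            hits.append(node)
    return hits


# Stand-in graph mimicking the relevant fields of torch.fx nodes (op, target, args).
fake_graph = SimpleNamespace(
    nodes=[
        SimpleNamespace(op="placeholder", target="x", args=()),
        SimpleNamespace(
            op="call_function",
            target="aten._native_batch_norm_legit.no_stats",
            args=("x", None, None, True, 0.1, 1e-05),
        ),
        SimpleNamespace(
            op="call_function",
            target="aten._native_batch_norm_legit.no_stats",
            args=("x", None, None, False, 0.1, 1e-05),
        ),
    ]
)
flagged = find_training_no_stats_nodes(fake_graph)
```

With the reproduction script, the same function would presumably be called on `edge_program.exported_program().graph`; that usage has not been checked against this ExecuTorch version.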

I built the runner using:

rm -rf cmake-out && mkdir cmake-out
cmake \
   -DCMAKE_INSTALL_PREFIX=cmake-out \
   -DCMAKE_BUILD_TYPE=Release \
   -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
   -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
   -DEXECUTORCH_ENABLE_LOGGING=ON \
   -DPYTHON_EXECUTABLE=python \
   -Bcmake-out .

cmake --build cmake-out -j16 --target install --config Release

Then when I run the lowered model:

./cmake-out/executor_runner --model_path instance_norm.pte
I 00:00:00.000259 executorch:executor_runner.cpp:73] Model file instance_norm.pte is loaded.
I 00:00:00.000271 executorch:executor_runner.cpp:82] Using method forward
I 00:00:00.000277 executorch:executor_runner.cpp:129] Setting up planned buffer 0, size 256.
I 00:00:00.000303 executorch:executor_runner.cpp:152] Method loaded.
I 00:00:00.000310 executorch:executor_runner.cpp:162] Inputs prepared.
E 00:00:00.000316 executorch:op_native_batch_norm.cpp:178] Check failed (training == false): Portable kernels only support inference mode!
E 00:00:00.000321 executorch:method.cpp:1068] KernelCall failed at instruction 0:1 in operator aten::_native_batch_norm_legit.no_stats_out: 0x12
E 00:00:00.000326 executorch:method.cpp:1074] arg 0 with type id 1
E 00:00:00.000329 executorch:method.cpp:1074] arg 1 with type id 0
E 00:00:00.000331 executorch:method.cpp:1074] arg 2 with type id 0
E 00:00:00.000333 executorch:method.cpp:1074] arg 3 with type id 5
E 00:00:00.000336 executorch:method.cpp:1074] arg 4 with type id 3
E 00:00:00.000338 executorch:method.cpp:1074] arg 5 with type id 3
E 00:00:00.000340 executorch:method.cpp:1074] arg 6 with type id 1
E 00:00:00.000342 executorch:method.cpp:1074] arg 7 with type id 1
E 00:00:00.000344 executorch:method.cpp:1074] arg 8 with type id 1
E 00:00:00.000346 executorch:method.cpp:1074] arg 9 with type id 9
F 00:00:00.000348 executorch:executor_runner.cpp:170] In function main(), assert failed (status == Error::Ok): Execution of method forward failed with status 0x12
[1]    91328 abort      ./cmake-out/executor_runner --model_path instance_norm.pte

Versions

PyTorch version: 2.5.0.dev20240716
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.2
Libc version: N/A

Python version: 3.10.14 (main, May  6 2024, 14:42:37) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] executorch==0.4.0a0+99e1ae1
[pip3] numpy==1.21.3
[pip3] torch==2.5.0.dev20240716
[pip3] torchaudio==2.4.0.dev20240716
[pip3] torchsr==1.0.4
[pip3] torchvision==0.20.0.dev20240716
[conda] executorch                0.4.0a0+99e1ae1          pypi_0    pypi
[conda] numpy                     1.21.3                   pypi_0    pypi
[conda] torch                     2.5.0.dev20240716          pypi_0    pypi
[conda] torchaudio                2.4.0.dev20240716          pypi_0    pypi
[conda] torchsr                   1.0.4                    pypi_0    pypi
[conda] torchvision               0.20.0.dev20240716          pypi_0    pypi

Labels

module: kernels (issues related to kernel libraries and utilities, and code under kernels/)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)