Skip to content

Add compliance suite skeleton and operator tests #11955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: gh/GregoryComer/65/head
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions backends/test/compliance_suite/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Operator Compliance Test Suite

This directory contains operator tests that all backends are expected to pass. While not every backend will implement every operator or permutation, the expectation is that backend partitioners will only partition nodes that the backend can support. The partitioner should never error out due to not supporting an input node.

## Backend Registration

To plug into the test framework, each backend should provide an implementation of the Tester class, defined in backends/test/harness/tester.py. Backends can provide implementations of each stage, or use the default implementation, as appropriate.

At a minimum, the backend will likely need to provide a custom implementation of the Partition and ToEdgeTransformAndLower stages using the appropriate backend partitioner. See backends/xnnpack/test/tester/tester.py for an example implementation.

Once a tester is available, the backend flow(s) can be added in `__init__.py` in this directory by adding an entry to `ALL_TEST_FLOWS`. Each flow entry consists of a name (used in the test case naming) and a function to instantiate a tester for a given model and input tuple.

## Test Cases

Operator test cases are defined under the `operators/` directory. Tests are written in a backend-independent manner, and each test is programmatically expanded to generate a variant for each registered backend flow. The `@operator_test` decorator is applied to each test class to trigger this behavior. Tests can also be tagged with an appropriate type specifier, such as `@dtype_test`, to generate variants for each dtype. The decorators and "magic" live in `__init__.py` in this directory.
3 changes: 3 additions & 0 deletions backends/test/compliance_suite/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Buck build entry point: the actual target definitions live in targets.bzl.
load(":targets.bzl", "define_common_targets")

define_common_targets(is_fbcode = True)
135 changes: 135 additions & 0 deletions backends/test/compliance_suite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import os
import unittest

from enum import Enum
from typing import Any, Callable, Tuple

import logging
import torch
from executorch.backends.test.harness import Tester

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_enabled_backends():
    """Return the list of backend names from the ET_TEST_BACKENDS environment
    variable (comma-separated), or None when the variable is unset.

    None signals that every registered backend should be enabled.
    """
    env_value = os.environ.get("ET_TEST_BACKENDS")
    return env_value.split(",") if env_value is not None else None

_ENABLED_BACKENDS = get_enabled_backends()

def is_backend_enabled(backend):
    """Return True when *backend* should be tested under the current filter.

    A filter of None (ET_TEST_BACKENDS unset) means no restriction.
    """
    return _ENABLED_BACKENDS is None or backend in _ENABLED_BACKENDS

# Registry of (flow_name, tester_factory) pairs, one per enabled backend.
# flow_name is used in generated test-case names; tester_factory builds a
# Tester for a given (module, inputs) pair.
ALL_TEST_FLOWS = []

# Backend testers are imported lazily so that a disabled backend's
# dependencies are never loaded.
if is_backend_enabled("xnnpack"):
    from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester

    XNNPACK_TEST_FLOW = ("xnnpack", XnnpackTester)
    ALL_TEST_FLOWS.append(XNNPACK_TEST_FLOW)

if is_backend_enabled("coreml"):
    from executorch.backends.apple.coreml.test.tester import CoreMLTester

    COREML_TEST_FLOW = ("coreml", CoreMLTester)
    ALL_TEST_FLOWS.append(COREML_TEST_FLOW)


# Dtypes that @dtype_test-marked tests are expanded over — one generated test
# variant per entry. Note: includes unsigned types (uint16/32/64) that not
# every backend supports; partitioners are expected to skip unsupported nodes.
DTYPES = [
    torch.int8,
    torch.uint8,
    torch.int16,
    torch.uint16,
    torch.int32,
    torch.uint32,
    torch.int64,
    torch.uint64,
    torch.float16,
    torch.float32,
    torch.float64,
]

class TestType(Enum):
    """How a test method is expanded: STANDARD yields one variant per backend
    flow; DTYPE additionally yields one variant per entry in DTYPES."""

    STANDARD = 1
    DTYPE = 2

def dtype_test(func):
    """Decorator marking *func* for expansion into one variant per dtype."""
    func.test_type = TestType.DTYPE
    return func

def operator_test(cls):
    """Class decorator: expand each test_* method on *cls* into per-backend
    (and, for @dtype_test methods, per-dtype) variants, then return *cls*."""
    _create_tests(cls)
    return cls

def _create_tests(cls):
    """Expand every attribute of *cls* whose name starts with 'test_'."""
    # Snapshot the names first: _expand_test mutates the class via delattr.
    test_names = [name for name in dir(cls) if name.startswith("test_")]
    for name in test_names:
        _expand_test(cls, name)

def _expand_test(cls, test_name: str):
    """Replace cls.<test_name> with one generated variant per backend flow."""
    template = getattr(cls, test_name)
    for flow_name, tester_factory in ALL_TEST_FLOWS:
        _create_test_for_backend(cls, template, flow_name, tester_factory)
    # Drop the unexpanded template so unittest does not collect it directly.
    delattr(cls, test_name)

def _create_test_for_backend(
cls,
test_func: Callable,
flow_name: str,
tester_factory: Callable[[torch.nn.Module, Tuple[Any]], Tester]
):
test_type = getattr(test_func, "test_type", TestType.STANDARD)

if test_type == TestType.STANDARD:
def wrapped_test(self):
test_func(self, tester_factory)

test_name = f"{test_func.__name__}_{flow_name}"
setattr(cls, test_name, wrapped_test)
elif test_type == TestType.DTYPE:
for dtype in DTYPES:
def wrapped_test(self):
test_func(self, dtype, tester_factory)

dtype_name = str(dtype)[6:] # strip "torch."
test_name = f"{test_func.__name__}_{dtype_name}_{flow_name}"
setattr(cls, test_name, wrapped_test)
else:
raise NotImplementedError(f"Unknown test type {test_type}.")


class OperatorTest(unittest.TestCase):
    """Base class for operator compliance tests; provides the shared body that
    generated per-backend test variants invoke."""

    def _test_op(self, model, inputs, tester_factory):
        """Export *model*, lower it through the backend flow, and — only when
        at least one node was delegated — run it and compare outputs."""
        tester = tester_factory(model, inputs).export().to_edge_transform_and_lower()

        graph = tester.stages[tester.cur].graph_module.graph
        is_delegated = any(
            node.op == "call_function"
            and node.target == torch._higher_order_ops.executorch_call_delegate
            for node in graph.nodes
        )

        # Only run the runtime test if the op was delegated; otherwise there is
        # no backend payload to execute.
        if is_delegated:
            tester.to_executorch().serialize().run_method_and_compare_outputs()

Empty file.
74 changes: 74 additions & 0 deletions backends/test/compliance_suite/operators/test_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

from typing import Callable

import torch

from executorch.backends.test.compliance_suite import (
dtype_test,
operator_test,
OperatorTest,
)

class Model(torch.nn.Module):
    """Elementwise addition of two tensors (with standard broadcasting)."""

    def forward(self, x, y):
        return torch.add(x, y)

class ModelAlpha(torch.nn.Module):
    """Addition with a scale factor on the second operand: x + alpha * y."""

    def __init__(self, alpha):
        super().__init__()
        self.alpha = alpha

    def forward(self, x, y):
        return x.add(y, alpha=self.alpha)

@operator_test
class Add(OperatorTest):
    """Compliance tests for elementwise addition."""

    @dtype_test
    def test_add_dtype(self, dtype, tester_factory: Callable) -> None:
        lhs = (torch.rand(2, 10) * 100).to(dtype)
        rhs = (torch.rand(2, 10) * 100).to(dtype)
        self._test_op(Model(), (lhs, rhs), tester_factory)

    def test_add_f32_bcast_first(self, tester_factory: Callable) -> None:
        # First operand broadcasts against a higher-rank second operand.
        inputs = (torch.randn(5), torch.randn(1, 5, 1, 5))
        self._test_op(Model(), inputs, tester_factory)

    def test_add_f32_bcast_second(self, tester_factory: Callable) -> None:
        # Second operand broadcasts against a higher-rank first operand.
        inputs = (torch.randn(4, 4, 2, 7), torch.randn(2, 7))
        self._test_op(Model(), inputs, tester_factory)

    def test_add_f32_bcast_unary(self, tester_factory: Callable) -> None:
        inputs = (torch.randn(5), torch.randn(1, 1, 5))
        self._test_op(Model(), inputs, tester_factory)

    def test_add_f32_alpha(self, tester_factory: Callable) -> None:
        inputs = (torch.randn(1, 25), torch.randn(1, 25))
        self._test_op(ModelAlpha(alpha=2), inputs, tester_factory)

82 changes: 82 additions & 0 deletions backends/test/compliance_suite/operators/test_div.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

from typing import Callable, Optional

import torch

from executorch.backends.test.compliance_suite import (
dtype_test,
operator_test,
OperatorTest,
)

class Model(torch.nn.Module):
    """Elementwise true division of two tensors (with broadcasting)."""

    def forward(self, x, y):
        return torch.div(x, y)

class ModelWithRounding(torch.nn.Module):
    """Division with an explicit rounding mode (None, "trunc", or "floor")."""

    def __init__(self, rounding_mode: Optional[str]):
        super().__init__()
        self.rounding_mode = rounding_mode

    def forward(self, x, y):
        return x.div(y, rounding_mode=self.rounding_mode)

@operator_test
class Divide(OperatorTest):
    """Compliance tests for elementwise division."""

    @dtype_test
    def test_divide_dtype(self, dtype, tester_factory: Callable) -> None:
        numerator = (torch.rand(2, 10) * 100).to(dtype)
        # The +0.1 offset keeps the denominator away from zero.
        denominator = (torch.rand(2, 10) * 100 + 0.1).to(dtype)
        self._test_op(Model(), (numerator, denominator), tester_factory)

    def test_divide_f32_bcast_first(self, tester_factory: Callable) -> None:
        # abs() + 0.1 keeps the denominator strictly positive.
        inputs = (torch.randn(5), torch.randn(1, 5, 1, 5).abs() + 0.1)
        self._test_op(Model(), inputs, tester_factory)

    def test_divide_f32_bcast_second(self, tester_factory: Callable) -> None:
        inputs = (torch.randn(4, 4, 2, 7), torch.randn(2, 7).abs() + 0.1)
        self._test_op(Model(), inputs, tester_factory)

    def test_divide_f32_bcast_unary(self, tester_factory: Callable) -> None:
        inputs = (torch.randn(5), torch.randn(1, 1, 5).abs() + 0.1)
        self._test_op(Model(), inputs, tester_factory)

    def test_divide_f32_trunc(self, tester_factory: Callable) -> None:
        inputs = (torch.randn(3, 4) * 10, torch.randn(3, 4).abs() + 0.1)
        self._test_op(ModelWithRounding(rounding_mode="trunc"), inputs, tester_factory)

    def test_divide_f32_floor(self, tester_factory: Callable) -> None:
        inputs = (torch.randn(3, 4) * 10, torch.randn(3, 4).abs() + 0.1)
        self._test_op(ModelWithRounding(rounding_mode="floor"), inputs, tester_factory)
56 changes: 56 additions & 0 deletions backends/test/compliance_suite/operators/test_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

from typing import Callable

import torch

from executorch.backends.test.compliance_suite import (
dtype_test,
operator_test,
OperatorTest,
)

class Model(torch.nn.Module):
    """Elementwise multiplication of two tensors (with broadcasting)."""

    def forward(self, x, y):
        return torch.mul(x, y)

@operator_test
class Multiply(OperatorTest):
    """Compliance tests for elementwise multiplication."""

    @dtype_test
    def test_multiply_dtype(self, dtype, tester_factory: Callable) -> None:
        lhs = (torch.rand(2, 10) * 100).to(dtype)
        rhs = (torch.rand(2, 10) * 100).to(dtype)
        self._test_op(Model(), (lhs, rhs), tester_factory)

    def test_multiply_f32_bcast_first(self, tester_factory: Callable) -> None:
        # First operand broadcasts against a higher-rank second operand.
        inputs = (torch.randn(5), torch.randn(1, 5, 1, 5))
        self._test_op(Model(), inputs, tester_factory)

    def test_multiply_f32_bcast_second(self, tester_factory: Callable) -> None:
        inputs = (torch.randn(4, 4, 2, 7), torch.randn(2, 7))
        self._test_op(Model(), inputs, tester_factory)

    def test_multiply_f32_bcast_unary(self, tester_factory: Callable) -> None:
        inputs = (torch.randn(5), torch.randn(1, 1, 5))
        self._test_op(Model(), inputs, tester_factory)
Loading
Loading