Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -876,14 +876,17 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
// Tensors with integer types need to be converted to signless integer
// element type. All tensors with element types other than integer can reuse
// existing elements attribute.
// TODO: what about unsigned integer?
if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
Type builtinTensorElemTy = resultType.getElementType();
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
bool isUnsigned =
cast<IntegerType>(builtinTensorElemTy).isUnsignedInteger();

DenseElementsAttr valueAttr =
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
return APInt(bitWidth, v.getSExtValue());
APInt intValue =
isUnsigned ? v.zextOrTrunc(bitWidth) : v.sextOrTrunc(bitWidth);
return intValue;
});
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
valueAttr);
Expand Down
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,6 +1771,8 @@
# raise TimeoutError(self.error_message)
# TimeoutError: Timeout
"BertModule_basic",
"UInt8Tensor_basic",
"BoolTensor_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2102,3 +2102,46 @@ def forward(self, a):
@register_test_case(module_factory=lambda: AtenDiagEmbedNonDefault4DDiag())
def AtenDiagEmbedNonDefault4DDiag_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4, 5))


# ==============================================================================


class UInt8Tensor(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
]
)
def forward(self):
x = torch.tensor([128], dtype=torch.uint8)
return torch.ops.aten.to(x, dtype=torch.float32)


@register_test_case(module_factory=lambda: UInt8Tensor())
def UInt8Tensor_basic(module, tu: TestUtils):
module.forward()


class BoolTensor(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
]
)
def forward(self):
x = torch.tensor([True], dtype=torch.bool)
return torch.ops.aten.to(x, dtype=torch.float32)


@register_test_case(module_factory=lambda: BoolTensor())
def BoolTensor_basic(module, tu: TestUtils):
module.forward()
Loading