Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,8 @@ static LogicalResult verifyConvOpErrorIf(T op) {
return success();

const int64_t biasChannels = biasType.getDimSize(0);
const int64_t outputChannels = outputType.getDimSize(3);
const int64_t outputChannels =
outputType.getDimSize(outputType.getRank() - 1);
if (biasChannels == ShapedType::kDynamic ||
outputChannels == ShapedType::kDynamic)
// Skip following checks if biasChannels or outputChannels is dynamic dim
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Tosa/availability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %

// -----
// CHECK-LABEL: conv3d
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<21xf32>) -> tensor<1x4x8x21x34xf32> {
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ]
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32>
return %0 : tensor<1x4x8x21x34xf32>
}

Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/Tosa/invalid_extension.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<8x1x1x4xi4>, %ar
}

// -----
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xi16>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<21xi48>, %arg3: tensor<1xi16>, %arg4: tensor<1xi8>) -> tensor<1x4x8x21x34xi48> {
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xi16>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi48>, %arg3: tensor<1xi16>, %arg4: tensor<1xi8>) -> tensor<1x4x8x21x34xi48> {
// expected-error@+1 {{'tosa.conv3d' op illegal: requires [int16] but not enabled in target}}
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i48, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xi16>, tensor<34x1x1x1x17xi8>, tensor<21xi48>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x4x8x21x34xi48>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i48, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xi16>, tensor<34x1x1x1x17xi8>, tensor<34xi48>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x4x8x21x34xi48>
return %0 : tensor<1x4x8x21x34xi48>
}

Expand Down Expand Up @@ -445,10 +445,10 @@ func.func @test_conv2d_non_const_input_zp(%arg0: tensor<1x4x4x4xi8>, %arg1: tens

// -----

func.func @test_conv3d_non_const_weight_zp(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<21xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> {
func.func @test_conv3d_non_const_weight_zp(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> {
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.conv3d' op expected compile time resolvable constant, but got variable value for operand #4}}
%0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %arg3 {acc_type = i32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<21xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi32>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %arg3 {acc_type = i32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi32>
return %0 : tensor<1x4x8x21x34xi32>
}

Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>

// -----
// CHECK-LABEL: conv3d
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<21xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x8x21x34xf32> {
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32>
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x8x21x34xf32> {
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32>
return %0 : tensor<1x4x8x21x34xf32>
}

// -----
// CHECK-LABEL: conv3d_with_local_bound
func.func @test_conv3d_with_local_bound(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<21xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x8x21x34xf32> {
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, local_bound = true} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32>
func.func @test_conv3d_with_local_bound(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x8x21x34xf32> {
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, local_bound = true} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32>
return %0 : tensor<1x4x8x21x34xf32>
}

Expand Down Expand Up @@ -823,8 +823,8 @@ func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1

// -----
// CHECK-LABEL: conv3d_f8E5M2
func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<21xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> {
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<21xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> {
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
return %0 : tensor<1x4x8x21x34xf16>
}

Expand Down Expand Up @@ -968,8 +968,8 @@ func.func @test_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<8

// -----
// CHECK-LABEL: conv3d_f8E4M3FN
func.func @test_conv3d_f8E4M3FN(%arg0: tensor<1x4x8x21x17xf8E4M3FN>, %arg1: tensor<34x1x1x1x17xf8E4M3FN>, %arg2: tensor<21xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16> {
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E4M3FN>, tensor<34x1x1x1x17xf8E4M3FN>, tensor<21xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16>
func.func @test_conv3d_f8E4M3FN(%arg0: tensor<1x4x8x21x17xf8E4M3FN>, %arg1: tensor<34x1x1x1x17xf8E4M3FN>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16> {
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E4M3FN>, tensor<34x1x1x1x17xf8E4M3FN>, tensor<34xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16>
return %0 : tensor<1x4x8x21x34xf16>
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
}

// -----
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf16>, %arg1: tensor<34x1x1x1x17xf16>, %arg2: tensor<21xf16>, %arg3: tensor<1xf16>, %arg4: tensor<1xf16>) -> tensor<1x4x8x21x34xf16> {
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf16>, %arg1: tensor<34x1x1x1x17xf16>, %arg2: tensor<34xf16>, %arg3: tensor<1xf16>, %arg4: tensor<1xf16>) -> tensor<1x4x8x21x34xf16> {
// expected-error@+1 {{'tosa.conv3d' op illegal: requires [pro_fp] but not enabled in target}}
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf16>, tensor<34x1x1x1x17xf16>, tensor<21xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x4x8x21x34xf16>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf16>, tensor<34x1x1x1x17xf16>, tensor<34xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x4x8x21x34xf16>
return %0 : tensor<1x4x8x21x34xf16>
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<8x1x1x4xi8>, %ar
}

// -----
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<21xi32>, %arg3: tensor<1xi8>, %arg4: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> {
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi32>, %arg3: tensor<1xi8>, %arg4: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> {
// expected-error@+1 {{'tosa.conv3d' op illegal: requires [pro_int] but not enabled in target}}
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<21xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi32>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi32>
return %0 : tensor<1x4x8x21x34xi32>
}

Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -824,18 +824,18 @@ func.func @conv2d_strided(%input: tensor<1x13x15x1xf32>, %weights: tensor<1x1x1x
// -----

// CHECK-LABEL: @conv3d_static
func.func @conv3d_static(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<5x3x6x4x3xf32>, %bias: tensor<7xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) -> () {
func.func @conv3d_static(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<5x3x6x4x3xf32>, %bias: tensor<5xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) -> () {
// CHECK: -> tensor<2x6x4x7x5xf32>
%0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
%0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
return
}

// -----

// CHECK-LABEL: @conv3d_dynamic_input
func.func @conv3d_dynamic_input(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<7xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
func.func @conv3d_dynamic_input(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<?x?x?x?x5xf32>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<?x?x?x?x?xf32>, tensor<5x3x6x4x3xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<?x?x?x?x?xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
return
}

Expand All @@ -860,18 +860,18 @@ func.func @conv3d_dynamic_bias(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x
// -----

// CHECK-LABEL: @conv3d_padded
func.func @conv3d_padded(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<18xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
func.func @conv3d_padded(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x9x11x18x5xf32>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 1, 2, 3, 4, 5, 6>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<18xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 1, 2, 3, 4, 5, 6>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
return
}

// -----

// CHECK-LABEL: @conv3d_dilated
func.func @conv3d_dilated(%arg0: tensor<2x12x14x16x3xf32>, %arg1: tensor<5x3x6x2x3xf32>, %arg2: tensor<12xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
func.func @conv3d_dilated(%arg0: tensor<2x12x14x16x3xf32>, %arg1: tensor<5x3x6x2x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x6x4x12x5xf32>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 3, 2, 4>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x12x14x16x3xf32>, tensor<5x3x6x2x3xf32>, tensor<12xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 3, 2, 4>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x12x14x16x3xf32>, tensor<5x3x6x2x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
return
}

Expand Down
Loading