Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added testdata/dnn/onnx/data/input_einsum_1d_0.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_1d_1.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_2d_0.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_2d_1.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_3d_0.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_3d_1.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_4d_0.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_4d_1.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_5d_0.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_5d_1.npy
Binary file not shown.
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_hadamard_0.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_hadamard_1.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_inner_0.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_inner_1.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_sum.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_einsum_transpose.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_einsum_1d.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_einsum_2d.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_einsum_3d.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_einsum_4d.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_einsum_5d.npy
Binary file not shown.
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_einsum_hadamard.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_einsum_inner.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_einsum_sum.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_einsum_transpose.npy
Binary file not shown.
91 changes: 91 additions & 0 deletions testdata/dnn/onnx/generate_onnx_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,97 @@ def step(self) -> Tuple[np.ndarray, np.ndarray]:
Y = np.squeeze(Y)
return Y, Y_h

class Einsum(nn.Module):
def __init__(self, equation):
super(Einsum, self).__init__()
self.equation = equation

def forward(self, one, two):
return torch.einsum(self.equation, one, two)

class EinsumSingleInput(Einsum):
def forward(self, x):
return torch.einsum(self.equation, x)

# inner/dot product
mat1 = torch.ones(4)
mat2 = torch.ones(4)
equation = 'i,i'
einsum = Einsum(equation)
output = einsum(mat1, mat2)

save_data_and_model_multy_inputs("einsum_inner", einsum, mat1, mat2, export_params=True)

# 1d hadamard
mat1 = torch.ones(4)
mat2 = torch.ones(4)
equation = 'i,i->i'
einsum = Einsum(equation)
output = einsum(mat1, mat2)

save_data_and_model_multy_inputs("einsum_hadamard", einsum, mat1, mat2, export_params=True)

# 2d test case
mat1 = torch.randn(4, 5)
mat2 = torch.randn(5, 8)
equation = 'ij,jk->ik'
einsum = Einsum(equation)
output = einsum(mat1, mat2)

save_data_and_model_multy_inputs("einsum_2d", einsum, mat1, mat2, export_params=True)

# 3d test case
mat1 = torch.ones(2, 4, 5)
mat2 = torch.ones(2, 5, 8)
equation = 'bij,bjk->bik'
einsum = Einsum(equation)
output = einsum(mat1, mat2)

save_data_and_model_multy_inputs("einsum_3d", einsum, mat1, mat2, export_params=True)

# 4d test case
mat1 = torch.randn(1, 4, 7, 9)
mat2 = torch.randn(1, 5, 9, 8)
equation = 'imkj,injs->imnks'
einsum = Einsum(equation)
output = einsum(mat1, mat2)

save_data_and_model_multy_inputs("einsum_4d", einsum, mat1, mat2, export_params=True)

# 5d test case
mat1 = torch.randn(4, 2, 3, 4, 5)
mat2 = torch.randn(4, 2, 3, 5, 8)
equation = 'bhijk,bhikc->bhijc'
einsum = Einsum(equation)
output = einsum(mat1, mat2)

save_data_and_model_multy_inputs("einsum_5d", einsum, mat1, mat2, export_params=True)

# sum
mat = torch.randn(3, 4)
equation = "ij->i"
einsum = EinsumSingleInput(equation)
output = einsum(mat)

save_data_and_model("einsum_sum", mat, einsum, export_params=True)

# sum
mat = torch.randn(3, 5, 5)
equation = "...ii ->...i"
einsum = EinsumSingleInput(equation)
output = einsum(mat)

save_data_and_model("einsum_batch_diagonal", mat, einsum, export_params=True)

# einsum transpose
mat = torch.randn(3, 4)
equation = "ij->ji"
einsum = EinsumSingleInput(equation)
output = einsum(mat)

save_data_and_model("einsum_transpose", mat, einsum, export_params=True)


def _extract_value_info(x, name, type_proto=None): # type: (Union[List[Any], np.ndarray, None], Text, Optional[TypeProto]) -> onnx.ValueInfoProto
if type_proto is None:
if x is None:
Expand Down
16 changes: 16 additions & 0 deletions testdata/dnn/onnx/models/einsum_1d.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
pytorch2.0.0:�
5
one
two2/Einsum"Einsum*
equation"i,i->i� torch_jitZ
one


Z
two


b
2

Einsum2_dim_0B
Expand Down
17 changes: 17 additions & 0 deletions testdata/dnn/onnx/models/einsum_2d.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pytorch2.0.0:�
8
one
two2/Einsum"Einsum*
equation" ij,jk->ik� torch_jitZ
one


Z
two


b-
2(
&"
Einsum2_dim_0
Einsum2_dim_1B
Expand Down
20 changes: 20 additions & 0 deletions testdata/dnn/onnx/models/einsum_3d.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
pytorch2.0.0:�
;
one
two2/Einsum"Einsum*
equation" bij,bjk->bik� torch_jitZ
one



Z
two



b>
29
73
Einsum2_dim_0
Einsum2_dim_1
Einsum2_dim_2B
Expand Down
24 changes: 24 additions & 0 deletions testdata/dnn/onnx/models/einsum_4d.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
pytorch2.0.0:�
?
one
two2/Einsum"Einsum*
equation"imkj,injs->imnks� torch_jitZ
one




 Z
two




b`
2[
YU
Einsum2_dim_0
Einsum2_dim_1
Einsum2_dim_2
Einsum2_dim_3
Einsum2_dim_4B
Expand Down
26 changes: 26 additions & 0 deletions testdata/dnn/onnx/models/einsum_5d.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
pytorch2.0.0:�
A
one
two2/Einsum"Einsum*!
equation"bhijk,bhikc->bhijc� torch_jitZ!
one





Z!
two





b`
2[
YU
Einsum2_dim_0
Einsum2_dim_1
Einsum2_dim_2
Einsum2_dim_3
Einsum2_dim_4B
Expand Down
13 changes: 13 additions & 0 deletions testdata/dnn/onnx/models/einsum_batch_diagonal.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
pytorch2.0.0:�
4
x1/Einsum"Einsum*
equation" ...ii ->...i� torch_jitZ
x



b-
1(
&"
Einsum1_dim_0
Einsum1_dim_1B
Expand Down
16 changes: 16 additions & 0 deletions testdata/dnn/onnx/models/einsum_hadamard.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
pytorch2.0.0:�
5
one
two2/Einsum"Einsum*
equation"i,i->i� torch_jitZ
one


Z
two


b
2

Einsum2_dim_0B
Expand Down
Binary file added testdata/dnn/onnx/models/einsum_inner.onnx
Binary file not shown.
11 changes: 11 additions & 0 deletions testdata/dnn/onnx/models/einsum_sum.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
pytorch2.0.0:m
-
x1/Einsum"Einsum*
equation"ij->i� torch_jitZ
x


b
1

Einsum1_dim_0B
Expand Down
12 changes: 12 additions & 0 deletions testdata/dnn/onnx/models/einsum_transpose.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
pytorch2.0.0:
.
x1/Einsum"Einsum*
equation"ij->ji� torch_jitZ
x


b-
1(
&"
Einsum1_dim_0
Einsum1_dim_1B
Expand Down