Skip to content

Commit 3595b4a

Browse files
author
Prashant Kumar
committed
Incorporate latest changes in the shark_dynamo backend.
1 parent 3a9cfe1 commit 3595b4a

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

shark/examples/shark_dynamo/basic_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import torchdynamo
21
import torch
32
import torch_mlir
3+
import torch._dynamo as torchdynamo
44
from shark.sharkdynamo.utils import make_shark_compiler
55

66

shark/sharkdynamo/utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import List, Optional
44
import torch
55
from torch.fx.experimental.proxy_tensor import make_fx
6-
from functorch._src.compile_utils import strip_overloads
6+
from torch._functorch.compile_utils import strip_overloads
77
from shark.shark_inference import SharkInference
88
from torch._decomp import get_decompositions
99

@@ -119,14 +119,19 @@ def compiler(
119119
example_inputs,
120120
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
121121
)
122+
import io
123+
124+
bytecode_stream = io.BytesIO()
125+
linalg_module.operation.write_bytecode(bytecode_stream)
126+
mlir_module = bytecode_stream.getvalue()
122127

123128
shark_module = SharkInference(
124-
linalg_module, "forward", mlir_dialect="linalg", device=device
129+
mlir_module, mlir_dialect="linalg", device=device
125130
)
126131
shark_module.compile()
127132

128133
def forward(*inputs):
129-
result = shark_module.forward(inputs)
134+
result = shark_module("forward", inputs)
130135
result = tuple() if result is None else result
131136
return (result,) if was_unwrapped else result
132137

0 commit comments

Comments
 (0)