Skip to content

Commit 7512a61

Browse files
committed
LHS index + FAQ:dynamic length input/output
1 parent 32425bf commit 7512a61

File tree

1 file changed

+74
-2
lines changed

1 file changed

+74
-2
lines changed

docs/source/onnx.rst

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,50 @@ The dynamic control flow is captured correctly. We can verify in backends with d
232232
Limitations
233233
-----------
234234

235-
* TODO: LHS indexing, indexed assignments.
235+
* Tensor in-place indexed assignment like `data[index] = new_data` is currently not supported in exporting.
236+
One way to resolve this kind of issue is to use operator `scatter`, explicitly updating the original tensor. ::
237+
238+
data = torch.zeros(3, 4)
239+
index = torch.tensor(1)
240+
new_data = torch.arange(4).to(torch.float32)
241+
242+
# Assigning to left hand side indexing is not supported in exporting.
243+
# class InPlaceIndexedAssignment(torch.nn.Module):
244+
# def forward(self, data, index, new_data):
245+
# data[index] = new_data
246+
# return data
247+
248+
class InPlaceIndexedAssignmentONNX(torch.nn.Module):
249+
def forward(self, data, index, new_data):
250+
new_data = new_data.unsqueeze(0)
251+
index = index.expand(1, new_data.size(1))
252+
data.scatter_(0, index, new_data)
253+
return data
254+
255+
out = InPlaceIndexedAssignmentONNX()(data, index, new_data)
256+
257+
torch.onnx.export(InPlaceIndexedAssignmentONNX(), (data, index, new_data), 'inplace_assign.onnx')
258+
259+
# caffe2
260+
import caffe2.python.onnx.backend as backend
261+
import onnx
262+
263+
onnx_model = onnx.load('inplace_assign.onnx')
264+
rep = backend.prepare(onnx_model)
265+
out_caffe2 = rep.run((torch.zeros(3, 4).numpy(), index.numpy(), new_data.numpy()))
266+
267+
assert torch.all(torch.eq(out, torch.tensor(out_caffe2)))
268+
269+
# onnxruntime
270+
import onnxruntime
271+
sess = onnxruntime.InferenceSession('inplace_assign.onnx')
272+
out_ort = sess.run(None, {
273+
sess.get_inputs()[0].name: torch.zeros(3, 4).numpy(),
274+
sess.get_inputs()[1].name: index.numpy(),
275+
sess.get_inputs()[2].name: new_data.numpy(),
276+
})
277+
278+
assert torch.all(torch.eq(out, torch.tensor(out_ort)))
236279

237280
* TODO: Tensor List.
238281

@@ -570,7 +613,36 @@ with matching custom ops implementation, e.g. `Caffe2 custom ops <https://caffe2
570613

571614
Frequently Asked Questions
572615
--------------------------
573-
Q:
616+
Q: I have exported my lstm model, but its input size seems to be fixed?
617+
618+
The tracer records the example inputs shape in the graph. In case the model should accept
619+
inputs of dynamic shape, you can utilize the parameter `dynamic_axes` in export api. ::
620+
621+
layer_count = 4
622+
623+
model = nn.LSTM(10, 20, num_layers=layer_count, bidirectional=True)
624+
model.eval()
625+
626+
with torch.no_grad():
627+
input = torch.randn(5, 3, 10)
628+
h0 = torch.randn(layer_count * 2, 3, 20)
629+
c0 = torch.randn(layer_count * 2, 3, 20)
630+
output, (hn, cn) = model(input, (h0, c0))
631+
632+
# default export
633+
torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx')
634+
onnx_model = onnx.load('lstm.onnx')
635+
# input shape [5, 3, 10]
636+
print(onnx_model.graph.input[0])
637+
638+
# export with `dynamic_axes`
639+
torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx',
640+
input_names=['input', 'h0', 'c0'],
641+
output_names=['output', 'hn', 'cn'],
642+
dynamic_axes={'input': {0: 'sequence'}, 'output': {0: 'sequence'}})
643+
onnx_model = onnx.load('lstm.onnx')
644+
# input shape ['sequence', 3, 10]
645+
print(onnx_model.graph.input[0])
574646

575647

576648
Q: How to export models with loops in it?

0 commit comments

Comments
 (0)