@@ -232,7 +232,50 @@ The dynamic control flow is captured correctly. We can verify in backends with d
232232Limitations
233233-----------
234234
235- * TODO: LHS indexing, indexed assignments.
235+ * Tensor in-place indexed assignment like `data[index] = new_data ` is currently not supported in exporting.
236+ One way to resolve this kind of issue is to use operator `scatter `, explicitly updating the original tensor. ::
237+
238+ data = torch.zeros(3, 4)
239+ index = torch.tensor(1)
240+ new_data = torch.arange(4).to(torch.float32)
241+
242+ # Assigning to left hand side indexing is not supported in exporting.
243+ # class InPlaceIndexedAssignment(torch.nn.Module):
244+ # def forward(self, data, index, new_data):
245+ # data[index] = new_data
246+ # return data
247+
248+ class InPlaceIndexedAssignmentONNX(torch.nn.Module):
249+ def forward(self, data, index, new_data):
250+ new_data = new_data.unsqueeze(0)
251+ index = index.expand(1, new_data.size(1))
252+ data.scatter_(0, index, new_data)
253+ return data
254+
255+ out = InPlaceIndexedAssignmentONNX()(data, index, new_data)
256+
257+ torch.onnx.export(InPlaceIndexedAssignmentONNX(), (data, index, new_data), 'inplace_assign.onnx')
258+
259+ # caffe2
260+ import caffe2.python.onnx.backend as backend
261+ import onnx
262+
263+ onnx_model = onnx.load('inplace_assign.onnx')
264+ rep = backend.prepare(onnx_model)
265+ out_caffe2 = rep.run((torch.zeros(3, 4).numpy(), index.numpy(), new_data.numpy()))
266+
267+ assert torch.all(torch.eq(out, torch.tensor(out_caffe2)))
268+
269+ # onnxruntime
270+ import onnxruntime
271+ sess = onnxruntime.InferenceSession('inplace_assign.onnx')
272+ out_ort = sess.run(None, {
273+ sess.get_inputs()[0].name: torch.zeros(3, 4).numpy(),
274+ sess.get_inputs()[1].name: index.numpy(),
275+ sess.get_inputs()[2].name: new_data.numpy(),
276+ })
277+
278+ assert torch.all(torch.eq(out, torch.tensor(out_ort)))
236279
237280* TODO: Tensor List.
238281
@@ -570,7 +613,36 @@ with matching custom ops implementation, e.g. `Caffe2 custom ops <https://caffe2
570613
571614Frequently Asked Questions
572615--------------------------
573- Q:
616+ Q: I have exported my lstm model, but its input size seems to be fixed?
617+
618+ The tracer records the example inputs shape in the graph. In case the model should accept
619+ inputs of dynamic shape, you can utilize the parameter `dynamic_axes ` in export api. ::
620+
621+ layer_count = 4
622+
623+ model = nn.LSTM(10, 20, num_layers=layer_count, bidirectional=True)
624+ model.eval()
625+
626+ with torch.no_grad():
627+ input = torch.randn(5, 3, 10)
628+ h0 = torch.randn(layer_count * 2, 3, 20)
629+ c0 = torch.randn(layer_count * 2, 3, 20)
630+ output, (hn, cn) = model(input, (h0, c0))
631+
632+ # default export
633+ torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx')
634+ onnx_model = onnx.load('lstm.onnx')
635+ # input shape [5, 3, 10]
636+ print(onnx_model.graph.input[0])
637+
638+ # export with `dynamic_axes`
639+ torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx',
640+ input_names=['input', 'h0', 'c0'],
641+ output_names=['output', 'hn', 'cn'],
642+ dynamic_axes={'input': {0: 'sequence'}, 'output': {0: 'sequence'}})
643+ onnx_model = onnx.load('lstm.onnx')
644+ # input shape ['sequence', 3, 10]
645+ print(onnx_model.graph.input[0])
574646
575647
576648Q: How to export models with loops in it?
0 commit comments