From 9833ab0b8032982d20e7741545f32df672621d39 Mon Sep 17 00:00:00 2001 From: Amir Arsalan Soltani Date: Tue, 22 Jan 2019 11:50:33 -0500 Subject: [PATCH] Update convert_torch.py --- convert_torch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/convert_torch.py b/convert_torch.py index b639f73..894bffa 100644 --- a/convert_torch.py +++ b/convert_torch.py @@ -41,7 +41,9 @@ def forward(self, input): def copy_param(m,n): - if m.weight is not None: n.weight.data.copy_(m.weight) + if m.weight is not None: + m.weight.data = m.weight.view(n.weight.size()) + n.weight.data.copy_(m.weight) if m.bias is not None: n.bias.data.copy_(m.bias) if hasattr(n,'running_mean'): n.running_mean.copy_(m.running_mean) if hasattr(n,'running_var'): n.running_var.copy_(m.running_var)