From 45e9e94ddf56bde55f5c27048964eeca7f8a9704 Mon Sep 17 00:00:00 2001 From: Amy Roberts Date: Thu, 26 May 2022 11:39:51 +0100 Subject: [PATCH] Add mapping for batch norm layer keys --- src/transformers/modeling_tf_pytorch_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 59846a892533..888277101412 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -163,6 +163,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a new_key = key.replace("gamma", "weight") if "beta" in key: new_key = key.replace("beta", "bias") + if "running_var" in key: + new_key = key.replace("running_var", "moving_variance") + if "running_mean" in key: + new_key = key.replace("running_mean", "moving_mean") if new_key: old_keys.append(key) new_keys.append(new_key)