From 94567641e33d5c16b2262ccbec33f2a603e612ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojtek=20Jasi=C5=84ski?= Date: Fri, 1 Nov 2024 19:46:46 +0100 Subject: [PATCH 1/3] Fix positional embedding resampling for non-square inputs in ViT --- timm/models/vision_transformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 8bc09e94fb..b3b0ddca07 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -669,9 +669,11 @@ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: if self.dynamic_img_size: B, H, W, C = x.shape + prev_grid_size = self.patch_embed.grid_size pos_embed = resample_abs_pos_embed( self.pos_embed, - (H, W), + new_size=(H, W), + old_size=prev_grid_size, num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, ) x = x.view(B, -1, C) From 67ed0e8a4c3feb206eec778a6a14fc6c65568dc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojtek=20Jasi=C5=84ski?= Date: Fri, 1 Nov 2024 23:24:13 +0100 Subject: [PATCH 2/3] fix pos embed dynamic resampling for deit --- timm/models/deit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/deit.py b/timm/models/deit.py index 63662c02d4..0072013bf6 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -75,9 +75,11 @@ def set_distilled_training(self, enable=True): def _pos_embed(self, x): if self.dynamic_img_size: B, H, W, C = x.shape + prev_grid_size = self.patch_embed.grid_size pos_embed = resample_abs_pos_embed( self.pos_embed, - (H, W), + new_size=(H, W), + old_size=prev_grid_size, num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, ) x = x.view(B, -1, C) From b8c4c6fb4b8a56962979f7bc26cb07748409e36d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojtek=20Jasi=C5=84ski?= Date: Fri, 1 Nov 2024 23:27:06 +0100 Subject: [PATCH 3/3] fix pos embed dynamic resampling for eva --- timm/models/eva.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/eva.py b/timm/models/eva.py index 62e986ba3b..fe87154050 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -560,9 +560,11 @@ def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if self.dynamic_img_size: B, H, W, C = x.shape if self.pos_embed is not None: + prev_grid_size = self.patch_embed.grid_size pos_embed = resample_abs_pos_embed( self.pos_embed, - (H, W), + new_size=(H, W), + old_size=prev_grid_size, num_prefix_tokens=self.num_prefix_tokens, ) else: