diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py
index 62f53bb39b..9c6dc6e6ae 100644
--- a/timm/layers/pos_embed_sincos.py
+++ b/timm/layers/pos_embed_sincos.py
@@ -354,6 +354,7 @@ def __init__(
         self.dim = dim
         self.max_res = max_res
         self.temperature = temperature
+        self.linear_bands = linear_bands
         self.in_pixels = in_pixels
         self.feat_shape = feat_shape
         self.ref_feat_shape = ref_feat_shape
@@ -383,17 +384,7 @@ def __init__(
             self.pos_embed_cos = None
         else:
             # cache full sin/cos embeddings if shape provided up front
-            emb_sin, emb_cos = build_rotary_pos_embed(
-                feat_shape=feat_shape,
-                dim=dim,
-                max_res=max_res,
-                linear_bands=linear_bands,
-                in_pixels=in_pixels,
-                ref_feat_shape=self.ref_feat_shape,
-                grid_offset=self.grid_offset,
-                grid_indexing=self.grid_indexing,
-                temperature=self.temperature,
-            )
+            emb_sin, emb_cos = self._get_pos_embed_values(feat_shape)
             self.bands = None
             self.register_buffer(
                 'pos_embed_sin',
@@ -406,6 +397,30 @@ def __init__(
                 persistent=False,
             )
 
+    def _get_pos_embed_values(self, feat_shape: List[int]):
+        emb_sin, emb_cos = build_rotary_pos_embed(
+            feat_shape=feat_shape,
+            dim=self.dim,
+            max_res=self.max_res,
+            temperature=self.temperature,
+            linear_bands=self.linear_bands,
+            in_pixels=self.in_pixels,
+            ref_feat_shape=self.ref_feat_shape,
+            grid_offset=self.grid_offset,
+            grid_indexing=self.grid_indexing,
+        )
+        return emb_sin, emb_cos
+
+    def update_feat_shape(self, feat_shape: List[int]):
+        if self.feat_shape is not None and feat_shape != self.feat_shape:
+            # only update if feat_shape was set and different from previous value
+            assert self.pos_embed_sin is not None
+            assert self.pos_embed_cos is not None
+            emb_sin, emb_cos = self._get_pos_embed_values(feat_shape)
+            self.pos_embed_sin = emb_sin.to(self.pos_embed_sin.device, self.pos_embed_sin.dtype)
+            self.pos_embed_cos = emb_cos.to(self.pos_embed_cos.device, self.pos_embed_cos.dtype)
+            self.feat_shape = feat_shape
+
     def get_embed(self, shape: Optional[List[int]] = None):
         if shape is not None and self.bands is not None:
             # rebuild embeddings every call, use if target shape changes
@@ -453,6 +468,7 @@ def __init__(
         self.max_res = max_res
         self.temperature = temperature
         self.in_pixels = in_pixels
+        self.linear_bands = linear_bands
         self.feat_shape = feat_shape
         self.ref_feat_shape = ref_feat_shape
         self.grid_offset = grid_offset
@@ -480,27 +496,40 @@ def __init__(
             self.pos_embed = None
         else:
             # cache full sin/cos embeddings if shape provided up front
-            embeds = build_rotary_pos_embed(
-                feat_shape=feat_shape,
-                dim=dim,
-                max_res=max_res,
-                linear_bands=linear_bands,
-                in_pixels=in_pixels,
-                ref_feat_shape=self.ref_feat_shape,
-                grid_offset=self.grid_offset,
-                grid_indexing=self.grid_indexing,
-                temperature=self.temperature,
-            )
             self.bands = None
             self.register_buffer(
                 'pos_embed',
-                torch.cat(embeds, -1),
+                self._get_pos_embed_values(feat_shape=feat_shape),
                 persistent=False,
             )
 
+    def _get_pos_embed_values(self, feat_shape: List[int]):
+        embeds = build_rotary_pos_embed(
+            feat_shape=feat_shape,
+            dim=self.dim,
+            max_res=self.max_res,
+            temperature=self.temperature,
+            linear_bands=self.linear_bands,
+            in_pixels=self.in_pixels,
+            ref_feat_shape=self.ref_feat_shape,
+            grid_offset=self.grid_offset,
+            grid_indexing=self.grid_indexing,
+        )
+        return torch.cat(embeds, -1)
+
+    def update_feat_shape(self, feat_shape: List[int]):
+        if self.feat_shape is not None and feat_shape != self.feat_shape:
+            # only update if feat_shape was set and different from previous value
+            assert self.pos_embed is not None
+            self.pos_embed = self._get_pos_embed_values(feat_shape).to(
+                device=self.pos_embed.device,
+                dtype=self.pos_embed.dtype,
+            )
+            self.feat_shape = feat_shape
+
     def get_embed(self, shape: Optional[List[int]] = None):
         if shape is not None and self.bands is not None:
-            # rebuild embeddings every call, use if target shape changes
+            # rebuild embeddings from cached bands every call, use if target shape changes
             embeds = build_rotary_pos_embed(
                 shape,
                 self.bands,
@@ -684,6 +713,7 @@ def __init__(
 
         head_dim = dim // num_heads
         assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}"
+
         freqs = init_random_2d_freqs(
             head_dim,
             depth,
@@ -692,18 +722,32 @@ def __init__(
             rotate=True,
         )  # (2, depth, num_heads, head_dim//2)
         self.freqs = nn.Parameter(freqs)
+
         if feat_shape is not None:
             # cache pre-computed grid
-            t_x, t_y = get_mixed_grid(
-                feat_shape,
-                grid_indexing=grid_indexing,
-                device=self.freqs.device
-            )
+            t_x, t_y = self._get_grid_values(feat_shape)
             self.register_buffer('t_x', t_x, persistent=False)
             self.register_buffer('t_y', t_y, persistent=False)
         else:
             self.t_x = self.t_y = None
 
+    def _get_grid_values(self, feat_shape: Optional[List[int]]):
+        t_x, t_y = get_mixed_grid(
+            feat_shape,
+            grid_indexing=self.grid_indexing,
+            device=self.freqs.device
+        )
+        return t_x, t_y
+
+    def update_feat_shape(self, feat_shape: Optional[List[int]]):
+        if self.feat_shape is not None and feat_shape != self.feat_shape:
+            assert self.t_x is not None
+            assert self.t_y is not None
+            t_x, t_y = self._get_grid_values(feat_shape)
+            self.t_x = t_x.to(self.t_x.device, self.t_x.dtype)
+            self.t_y = t_y.to(self.t_y.device, self.t_y.dtype)
+            self.feat_shape = feat_shape
+
     def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
         """Generate rotary embeddings for the given spatial shape.
 
diff --git a/timm/models/eva.py b/timm/models/eva.py
index be15794bed..bcfa3ee2cb 100644
--- a/timm/models/eva.py
+++ b/timm/models/eva.py
@@ -723,6 +723,35 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None)
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
+    def set_input_size(
+            self,
+            img_size: Optional[Tuple[int, int]] = None,
+            patch_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """Update the input image resolution and patch size.
+
+        Args:
+            img_size: New input resolution, if None current resolution is used.
+            patch_size: New patch size, if None existing patch size is used.
+        """
+        prev_grid_size = self.patch_embed.grid_size
+        self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
+
+        if self.pos_embed is not None:
+            num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
+            num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
+            if num_new_tokens != self.pos_embed.shape[1]:
+                self.pos_embed = nn.Parameter(resample_abs_pos_embed(
+                    self.pos_embed,
+                    new_size=self.patch_embed.grid_size,
+                    old_size=prev_grid_size,
+                    num_prefix_tokens=num_prefix_tokens,
+                    verbose=True,
+                ))
+
+        if self.rope is not None:
+            self.rope.update_feat_shape(self.patch_embed.grid_size)
+
     def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         if self.dynamic_img_size:
             B, H, W, C = x.shape
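
For context, a minimal usage sketch of the new API (the model name below is an illustrative EVA02 variant; any `Eva`-based timm model should behave the same way, assuming this diff is applied): `set_input_size()` resamples the absolute `pos_embed` when one exists and calls `update_feat_shape()` on the cached RoPE module so the non-persistent sin/cos buffers match the new feature grid.

```python
# Sketch only: assumes this PR is applied; the model name is an illustrative
# EVA02 variant from timm, substitute any Eva-based model.
import torch
import timm

model = timm.create_model('eva02_base_patch14_224.mim_in22k', pretrained=False)
model.eval()

# Switch the model to 448x448 input after construction. This resamples the
# absolute pos_embed (if present) and, via rope.update_feat_shape(), rebuilds
# the cached RoPE embeddings for the new 32x32 feature grid (448 / 14).
model.set_input_size(img_size=(448, 448))

with torch.no_grad():
    out = model(torch.randn(1, 3, 448, 448))
print(out.shape)
```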