From 11665118ba57825d662d2479822d4f48a25fe4a8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Sat, 18 Jun 2022 22:34:59 +0200 Subject: [PATCH 01/22] first commit: - added DPT in Flax - all non slow tests passes in local - still some nits have to be investigated --- docs/source/en/index.mdx | 2 +- docs/source/en/model_doc/dpt.mdx | 17 +- src/transformers/__init__.py | 14 + src/transformers/modeling_flax_outputs.py | 69 + .../models/auto/modeling_flax_auto.py | 36 + src/transformers/models/dpt/__init__.py | 33 +- .../models/dpt/modeling_flax_dpt.py | 1138 +++++++++++++++++ src/transformers/utils/dummy_flax_objects.py | 28 + tests/models/dpt/test_modeling_flax_dpt.py | 187 +++ utils/check_repo.py | 2 + 10 files changed, 1523 insertions(+), 3 deletions(-) create mode 100644 src/transformers/models/dpt/modeling_flax_dpt.py create mode 100644 tests/models/dpt/test_modeling_flax_dpt.py diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 687974ab0723..7c3cacab935d 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -214,7 +214,7 @@ Flax), PyTorch, and/or TensorFlow. | DETR | ❌ | ❌ | ✅ | ❌ | ❌ | | DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ | | DPR | ✅ | ✅ | ✅ | ✅ | ❌ | -| DPT | ❌ | ❌ | ✅ | ❌ | ❌ | +| DPT | ❌ | ❌ | ✅ | ❌ | ✅ | | ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ | | Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ | | FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/dpt.mdx b/docs/source/en/model_doc/dpt.mdx index cdf009c6c8a0..1f33507eae71 100644 --- a/docs/source/en/model_doc/dpt.mdx +++ b/docs/source/en/model_doc/dpt.mdx @@ -54,4 +54,19 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi ## DPTForSemanticSegmentation [[autodoc]] DPTForSemanticSegmentation - - forward \ No newline at end of file + - forward + +## FlaxDPTForSemanticSegmentation + +[[autodoc]] FlaxDPTForSemanticSegmentation + - __call__ + +## FlaxDPTForDepthEstimation + +[[autodoc]] FlaxDPTForDepthEstimation + - __call__ + +## FlaxDPTModel + +[[autodoc]] FlaxDPTModel + - __call__ \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a68420b127ad..c03a17e544f9 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2579,6 +2579,14 @@ "FlaxDistilBertPreTrainedModel", ] ) + _import_structure["models.dpt"].extend( + [ + "FlaxDPTModel", + "FlaxDPTPreTrainedModel", + "FlaxDPTForSemanticSegmentation", + "FlaxDPTForDepthEstimation", + ] + ) _import_structure["models.electra"].extend( [ "FlaxElectraForCausalLM", @@ -4794,6 +4802,12 @@ FlaxDistilBertModel, FlaxDistilBertPreTrainedModel, ) + from .models.dpt import ( + FlaxDPTForDepthEstimation, + FlaxDPTForSemanticSegmentation, + FlaxDPTModel, + FlaxDPTPreTrainedModel, + ) from .models.electra import ( FlaxElectraForCausalLM, FlaxElectraForMaskedLM, diff --git a/src/transformers/modeling_flax_outputs.py b/src/transformers/modeling_flax_outputs.py index 4f6cc5a901f8..51fbe101b36c 100644 --- a/src/transformers/modeling_flax_outputs.py +++ b/src/transformers/modeling_flax_outputs.py @@ -640,3 +640,72 @@ class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput): encoder_last_hidden_state: Optional[jnp.ndarray] = None encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None encoder_attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxDepthEstimatorOutput(ModelOutput): + """ + Base class for outputs of depth estimation models. 
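+    All fields are `jnp.ndarray`s; the shape descriptions below are carried over from the PyTorch output classes.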
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`): + Predicted depth for each pixel. + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: jnp.ndarray = None + predicted_depth: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxSemanticSegmenterOutput(ModelOutput): + """ + Base class for outputs of depth estimation models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Classification scores for each pixel. + + + + The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is + to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the + original image size as post-processing. You should always check your logits shape and resize as needed. + + + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + loss: jnp.ndarray = None + logits: jnp.ndarray = None + predicted_depth: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 98c5d6fb5a10..6d3aa807d7ad 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -37,6 +37,7 @@ ("blenderbot-small", "FlaxBlenderbotSmallModel"), ("clip", "FlaxCLIPModel"), ("distilbert", "FlaxDistilBertModel"), + ("dpt", "FlaxDPTModel"), ("electra", "FlaxElectraModel"), ("gpt2", "FlaxGPT2Model"), ("gpt_neo", "FlaxGPTNeoModel"), @@ -211,6 +212,17 @@ ] ) +FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING = OrderedDict( + [ + ("dpt", "FlaxDPTForDepthEstimation"), + ] +) + +FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = OrderedDict( + [ + ("dpt", "FlaxDPTForSemanticSegmentation"), + ] +) FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES) FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES) @@ -241,6 +253,12 @@ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES ) +FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING +) +FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING +) class FlaxAutoModel(_BaseAutoModelClass): @@ -344,3 +362,21 @@ class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass): FlaxAutoModelForSpeechSeq2Seq = auto_class_update( FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling" ) + + +class FlaxAutoModelForDepthEstimation(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING + + +FlaxAutoModelForDepthEstimation = auto_class_update( + FlaxAutoModelForDepthEstimation, head_doc="depth estimation modeling" +) + + +class FlaxAutoModelForSemanticSegmentation(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING + + +FlaxAutoModelForSemanticSegmentation = auto_class_update( + FlaxAutoModelForSemanticSegmentation, head_doc="semantic segmentation modeling" +) diff --git a/src/transformers/models/dpt/__init__.py b/src/transformers/models/dpt/__init__.py index 1df82ab62824..2de43595c57d 100644 --- a/src/transformers/models/dpt/__init__.py +++ b/src/transformers/models/dpt/__init__.py @@ -17,7 +17,13 @@ # limitations under the License. 
from typing import TYPE_CHECKING -from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available, is_vision_available +from ...file_utils import ( + _LazyModule, + is_flax_available, + is_tokenizers_available, + is_torch_available, + is_vision_available, +) from ...utils import OptionalDependencyNotAvailable @@ -45,6 +51,19 @@ "DPTPreTrainedModel", ] +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_dpt"] = [ + "FlaxDPTForSemanticSegmentation", + "FlaxDPTForDepthEstimation", + "FlaxDPTModel", + "FlaxDPTPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_dpt import DPT_PRETRAINED_CONFIG_ARCHIVE_MAP, DPTConfig @@ -71,6 +90,18 @@ DPTPreTrainedModel, ) + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_dpt import ( + FlaxDPTForDepthEstimation, + FlaxDPTForSemanticSegmentation, + FlaxDPTModel, + FlaxDPTPreTrainedModel, + ) else: import sys diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py new file mode 100644 index 000000000000..83ac0257c641 --- /dev/null +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -0,0 +1,1138 @@ +# coding=utf-8 +# Copyright 2022 Intel Labs, OpenMMLab and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Flax DPT (Dense Prediction Transformers) model. + +TThis implementation is heavily inspired by OpenMMLab's implementation, found here: +https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/dpt_head.py. + +""" +import math +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPooling, + FlaxDepthEstimatorOutput, + FlaxSemanticSegmenterOutput, +) +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward +from .configuration_dpt import DPTConfig + + +DPT_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) This + model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. 
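+    The vision backbone mirrors the Flax ViT implementation, with DPT-specific reassemble, fusion and prediction
+    head modules on top.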
Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + Parameters: + config ([`ViTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision + inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`. + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] + and [`~FlaxPreTrainedModel.to_bf16`]. +""" + +DPT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`DPTFeatureExtractor`]. See + [`DPTFeatureExtractor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +class FlaxPatchEmbeddings(nn.Module): + + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + image_size = self.config.image_size + patch_size = self.config.patch_size + num_patches = (image_size // patch_size) * (image_size // patch_size) + self.num_patches = num_patches + self.projection = nn.Conv( + self.config.hidden_size, + kernel_size=(patch_size, patch_size), + strides=(patch_size, patch_size), + padding="VALID", + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + def __call__(self, pixel_values): + x = self.projection(pixel_values) + batch_size, _, _, channels = x.shape + return jnp.reshape(x, (batch_size, -1, channels)) + + +class FlaxDPTEmbeddings(nn.Module): + """Construct the CLS token, position and patch embeddings.""" + + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size)) + self.patch_embeddings = FlaxPatchEmbeddings(self.config, dtype=self.dtype) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = self.param( + "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size) + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, pixel_values, deterministic=True): + batch_size = pixel_values.shape[0] + + embeddings = self.patch_embeddings(pixel_values) + + cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size)) + embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1) + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings, deterministic=deterministic) + return embeddings + + +class FlaxViTSelfAttention(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:" + " {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + use_bias=self.config.qkv_bias, + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + use_bias=self.config.qkv_bias, + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + use_bias=self.config.qkv_bias, + ) + + def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + 
dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxDPTViTOutput(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = hidden_states + attention_output + return hidden_states + + +class FlaxDPTViTSelfOutput(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxDPTViTAttention(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.attention = FlaxViTSelfAttention(self.config, dtype=self.dtype) + self.output = FlaxDPTViTSelfOutput(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False): + attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +class FlaxDPTViTIntermediate(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# DPT reassemble & Fusion +class FlaxDPTReassembleLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + factor: int = 1 + channels: int = None + + def setup(self): + # projection + self.projection = nn.Conv( + self.config.hidden_size, + kernel_size=(1, 1), + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + # up/down sampling depending on factor + if self.factor > 1: + self.resize = nn.ConvTranspose( + self.channels, kernel_size=(self.factor, self.factor), strides=(self.factor, self.factor) + ) + elif self.factor < 1: + # so should downsample + self.resize = nn.Conv( + self.channels, + kernel_size=(3, 3), + 
strides=(int(1 / self.factor), int(1 / self.factor)), + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + def __call__(self, hidden_state): + hidden_state = self.projection(hidden_state) + if self.factor != 1: + hidden_state = self.resize(hidden_state) + return hidden_state + + +class FlaxDPTReassembleStage(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + + self.layers = [ + FlaxDPTReassembleLayer(self.config, factor=factor, channels=self.config.neck_hidden_sizes[i]) + for i, factor in zip(range(len(self.config.neck_hidden_sizes)), self.config.reassemble_factors) + ] + + if self.config.readout_type == "project": + self.readout_projects = [ + nn.Sequential([nn.Dense(self.config.hidden_size), ACT2FN[self.config.hidden_act]]) + for _ in range(len(self.config.neck_hidden_sizes)) + ] + + def __call__(self, hidden_states): + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`): + List of hidden states from the backbone. + """ + out = [] + + for i, hidden_state in enumerate(hidden_states): + # reshape to (B, C, H, W) + hidden_state, cls_token = hidden_state[:, 1:], hidden_state[:, 0] + batch_size, sequence_length, num_channels = hidden_state.shape + size = int(math.sqrt(sequence_length)) + hidden_state = jnp.reshape(hidden_state, (batch_size, size, size, num_channels)) + + feature_shape = hidden_state.shape + if self.config.readout_type == "project": + # reshape to (B, H*W, C) + hidden_state = jnp.reshape(hidden_state, (batch_size, size * size, num_channels)) + readout = jnp.expand_dims(cls_token, axis=1) + readout = jnp.repeat(readout, size * size, axis=1) + # concatenate the readout token to the hidden states and project + hidden_state = self.readout_projects[i](jnp.concatenate((hidden_state, readout), axis=-1)) + # reshape back to (B, C, H, W) + hidden_state = jnp.reshape(hidden_state, feature_shape) + elif self.config.readout_type == "add": + hidden_state = jnp.reshape(hidden_state, (batch_size, size * size, num_channels)) + jnp.expand_dims( + cls_token, axis=-1 + ) + hidden_state = jnp.reshape(hidden_state, feature_shape) + hidden_state = self.layers[i](hidden_state) + out.append(hidden_state) + + return out + + +class FlaxDPTFeatureFusionStage(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + super().__init__() + self.layers = [FlaxDPTFeatureFusionLayer(self.config) for _ in range(len(self.config.neck_hidden_sizes))] + + def __call__(self, hidden_states): + # reversing the hidden_states, we start from the last + hidden_states = hidden_states[::-1] + + fused_hidden_states = [] + # first layer only uses the last hidden_state + fused_hidden_state = self.layers[0](hidden_states[0]) + fused_hidden_states.append(fused_hidden_state) + # looping from the last layer to the second + for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]): + fused_hidden_state = layer(fused_hidden_state, hidden_state) + fused_hidden_states.append(fused_hidden_state) + + return fused_hidden_states + + +class FlaxDPTPreActResidualLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + + self.use_batch_norm = self.config.use_batch_norm_in_fusion_residual + self.activation1 = ACT2FN["relu"] + self.convolution1 = nn.Conv( + self.config.fusion_hidden_size, + kernel_size=(3, 3), + strides=(1, 1), + padding=1, + use_bias=not self.use_batch_norm, + ) + + self.activation2 = ACT2FN["relu"] + 
self.convolution2 = nn.Conv( + self.config.fusion_hidden_size, + kernel_size=(3, 3), + strides=(1, 1), + padding=1, + use_bias=not self.use_batch_norm, + ) + + if self.use_batch_norm: + self.batch_norm1 = nn.BatchNorm(use_running_average=False) + self.batch_norm2 = nn.BatchNorm(use_running_average=False) + + def __call__(self, hidden_state): + residual = hidden_state + hidden_state = self.activation1(hidden_state) + + hidden_state = self.convolution1(hidden_state) + + if self.use_batch_norm: + hidden_state = self.batch_norm1(hidden_state) + + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution2(hidden_state) + + if self.use_batch_norm: + hidden_state = self.batch_norm2(hidden_state) + + return hidden_state + residual + + def __call__(self, hidden_state): + residual = hidden_state + hidden_state = self.activation1(hidden_state) + + hidden_state = self.convolution1(hidden_state) + + if self.use_batch_norm: + hidden_state = self.batch_norm1(hidden_state) + + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution2(hidden_state) + + if self.use_batch_norm: + hidden_state = self.batch_norm2(hidden_state) + + return hidden_state + residual + + +class FlaxDPTFeatureFusionLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + align_corners: bool = True + + def setup(self): + self.projection = nn.Conv(self.config.fusion_hidden_size, kernel_size=(1, 1)) # , bias=True) + + self.residual_layer1 = FlaxDPTPreActResidualLayer(self.config) + self.residual_layer2 = FlaxDPTPreActResidualLayer(self.config) + self.upsample = FlaxDPTUpsample() + + def __call__(self, hidden_state, residual=None): + if residual is not None: + if hidden_state.shape != residual.shape: + size = hidden_state.shape + residual = self.upsample(residual, size) + hidden_state = hidden_state + self.residual_layer1(residual) + + hidden_state = self.residual_layer2(hidden_state) + hidden_state = self.upsample(hidden_state) + hidden_state = self.projection(hidden_state) + + return hidden_state + + +class FlaxDPTViTLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxDPTViTAttention(self.config, dtype=self.dtype) + self.intermediate = FlaxDPTViTIntermediate(self.config, dtype=self.dtype) + self.output = FlaxDPTViTOutput(self.config, dtype=self.dtype) + self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): + attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention + deterministic=deterministic, + output_attentions=output_attentions, + ) + + attention_output = attention_outputs[0] + + # first residual connection + attention_output = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(attention_output) + + hidden_states = self.intermediate(layer_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + return outputs + + +class FlaxDPTViTLayerCollection(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + 
self.layers = [ + FlaxDPTViTLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxDPTViTPooler(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + cls_hidden_state = hidden_states[:, 0] + cls_hidden_state = self.dense(cls_hidden_state) + return nn.tanh(cls_hidden_state) + + +class FlaxDPTViTEncoder(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxDPTViTLayerCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +DPT_START_DOCSTRING = r""" + This model is a Flax [jax.nn.Module](https://jax.readthedocs.io/en/latest/jax.nn.html?highlight=nn.Module) + subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matter related to general + usage and behavior. + + Parameters: + config ([`DPTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DPT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`jax.numpy.array` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`ViTFeatureExtractor`]. See + [`ViTFeatureExtractor.__call__`] for details. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
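+
+    Note that `pixel_values` are expected in channels-first (NCHW) format and are transposed to channels-last
+    (NHWC) internally, since Flax convolutions operate on channels-last tensors.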
+""" + + +class FlaxDPTPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DPTConfig + base_model_prefix = "dpt" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: DPTConfig, + input_shape=None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, 3) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + pixel_values, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[jnp.ndarray] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_attentions, + output_hidden_states, + return_dict, + labels, + rngs=rngs, + ) + + +class FlaxDPTModule(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + + def setup(self): + self.embeddings = FlaxDPTEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxDPTViTEncoder(self.config, dtype=self.dtype) + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.pooler = FlaxDPTViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None + + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + labels: Optional[jnp.ndarray] = None, + ): + + hidden_states = self.embeddings(pixel_values, deterministic=deterministic) + + outputs = self.encoder( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = 
outputs[0] + hidden_states = self.layernorm(hidden_states) + pooled = self.pooler(hidden_states) if self.add_pooling_layer else None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DPT Model with a semantic segmentation head on top e.g. for ADE20k, CityScapes. + """, + DPT_START_DOCSTRING, +) +class FlaxDPTModel(FlaxDPTPreTrainedModel): + module_class = FlaxDPTModule + + +class FlaxDPTNeck(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # postprocessing + self.reassemble_stage = FlaxDPTReassembleStage(self.config) + self.conv_list = [ + nn.Conv(self.config.fusion_hidden_size, kernel_size=(3, 3), padding=1, use_bias=False) + for i in range(len(self.config.neck_hidden_sizes)) + ] + # fusion + self.fusion_stage = FlaxDPTFeatureFusionStage(self.config) + + def __call__(self, hidden_states): + if not isinstance(hidden_states, list): + raise ValueError("hidden_states should be a list of tensors") + + if len(hidden_states) != len(self.config.neck_hidden_sizes): + raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.") + + # postprocess hidden states + features = self.reassemble_stage(hidden_states) + + features = [self.conv_list[i](feature) for i, feature in enumerate(features)] + + # fusion blocks + output = self.fusion_stage(features) + + return output + + +class FlaxDPTUpsample(nn.Module): + scale: int = 2 + method: str = "bilinear" + + def setup(self): + pass + + def __call__(self, x, output_size=None): + if output_size is None: + output_size = x.shape + output_size = (output_size[0], output_size[1] * self.scale, output_size[2] * self.scale, output_size[3]) + return jax.image.resize(x, output_size, method="bilinear") + + +class FlaxDPTDepthEstimationHead(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + + features = self.config.fusion_hidden_size + self.head = nn.Sequential( + [ + nn.Conv(features // 2, kernel_size=(3, 3), strides=(1, 1), padding=1), + FlaxDPTUpsample(scale=2, method="bilinear"), + nn.Conv(32, kernel_size=(3, 3), strides=(1, 1), padding=1), + ACT2FN["relu"], + nn.Conv(1, kernel_size=(1, 1), strides=(1, 1), padding=0), + ACT2FN["relu"], + ] + ) + + def __call__(self, hidden_states): + # use last features + hidden_states = hidden_states[self.config.head_in_index] + + predicted_depth = self.head(hidden_states) + + predicted_depth = jnp.squeeze(predicted_depth, -1) + + return predicted_depth + + +class FlaxDPTForDepthEstimationModule(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + + self.dpt = FlaxDPTModule(self.config, add_pooling_layer=False) + + # Neck + self.neck = FlaxDPTNeck(self.config) + + # Depth estimation head + self.head = FlaxDPTDepthEstimationHead(self.config) + + @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING) + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth depth estimation maps for computing the loss. 
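+            Note that loss computation is not implemented in this Flax module yet; passing `labels` currently raises
+            `NotImplementedError`.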
+ + Returns: + + Examples: + ```python + >>> from transformers import DPTFeatureExtractor, DPTForDepthEstimation + >>> import torch + >>> import numpy as np + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") + >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large") + + >>> # prepare image for the model + >>> inputs = feature_extractor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + ... predicted_depth = outputs.predicted_depth + + >>> # interpolate to original size + >>> prediction = torch.nn.functional.interpolate( + ... predicted_depth.unsqueeze(1), + ... size=image.size[::-1], + ... mode="bicubic", + ... align_corners=False, + ... ) + + >>> # visualize the prediction + >>> output = prediction.squeeze().cpu().numpy() + >>> formatted = (output * 255 / np.max(output)).astype("uint8") + >>> depth = Image.fromarray(formatted) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs = self.dpt( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states if return_dict else outputs + + # only keep certain features based on config.backbone_out_indices + # note that the hidden_states also include the initial embeddings + if return_dict: + hidden_states = [ + feature for idx, feature in enumerate(hidden_states) if idx in self.config.backbone_out_indices + ] + else: + hidden_states = [ + feature for idx, feature in enumerate(hidden_states[1]) if idx in self.config.backbone_out_indices + ] + + hidden_states = self.neck(hidden_states) + + predicted_depth = self.head(hidden_states) + + loss = None + if labels is not None: + raise NotImplementedError("Training is not implemented yet") + + if not return_dict: + if output_hidden_states: + output = (predicted_depth,) + outputs[1:] + else: + output = (predicted_depth,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return FlaxDepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2. 
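+    The head upsamples the fused features and predicts one depth value per pixel.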
+ """, + DPT_START_DOCSTRING, +) +class FlaxDPTForDepthEstimation(FlaxDPTPreTrainedModel): + module_class = FlaxDPTForDepthEstimationModule + + +class FlaxDPTSemanticSegmentationHead(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + # TODO: Change this to make sure BatchNorm + Dropout work / Put them outside a Sequential Module + def setup(self): + features = self.config.fusion_hidden_size + self.head = nn.Sequential( + [ + nn.Conv(features, kernel_size=(3, 3), padding=1), + # nn.BatchNorm(use_running_average=False), + ACT2FN["relu"], + # nn.Dropout(self.config.semantic_classifier_dropout, deterministic=False), + nn.Conv(self.config.num_labels, kernel_size=(1, 1)), + FlaxDPTUpsample(scale=2, method="bilinear"), + ] + ) + + def __call__(self, hidden_states): + # use last features + hidden_states = hidden_states[self.config.head_in_index] + + logits = self.head(hidden_states) + + return logits + + +class FlaxDPTAuxiliaryHead(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + # TODO: Change this to make sure BatchNorm + Dropout work / Put them outside a Sequential Module + def setup(self): + features = self.config.fusion_hidden_size + self.head = nn.Sequential( + [ + nn.Conv(features, kernel_size=(3, 3), padding=1, use_bias=False), # bias=False + # nn.BatchNorm(use_running_average=False), + ACT2FN["relu"], + # nn.Dropout(0.1, deterministic=False), + nn.Conv(self.config.num_labels, kernel_size=(1, 1)), + ] + ) + + def __call__(self, hidden_states): + logits = self.head(hidden_states) + + return logits + + +class FlaxDPTForSemanticSegmentationModule(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + + self.dpt = FlaxDPTModule(self.config, add_pooling_layer=False) + + # Neck + self.neck = FlaxDPTNeck(self.config) + + # Segmentation head(s) + self.head = FlaxDPTSemanticSegmentationHead(self.config) + self.auxiliary_head = FlaxDPTAuxiliaryHead(self.config) if self.config.use_auxiliary_head else None + + self.upsample = FlaxDPTUpsample(scale=2, method="bilinear") + + @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING) + def __call__( + self, + pixel_values=None, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). 
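+            In this Flax module the loss is computed with `optax.softmax_cross_entropy` on one-hot encoded labels,
+            and the auxiliary head loss is weighted by `config.auxiliary_loss_weight`.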
+ + Returns: + + Examples: + ```python + >>> from transformers import DPTFeatureExtractor, DPTForSemanticSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large-ade") + >>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade") + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs = self.dpt( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states if return_dict else outputs + + # only keep certain features based on config.backbone_out_indices + # note that the hidden_states also include the initial embeddings + if return_dict: + hidden_states = [ + feature for idx, feature in enumerate(hidden_states) if idx in self.config.backbone_out_indices + ] + else: + hidden_states = [ + feature for idx, feature in enumerate(hidden_states[1]) if idx in self.config.backbone_out_indices + ] + hidden_states = self.neck(hidden_states) + + logits = self.head(hidden_states) + + auxiliary_logits = None + if self.auxiliary_head is not None: + auxiliary_logits = self.auxiliary_head(hidden_states[-1]) + + loss = None + if labels is not None: + if self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + else: + # upsample logits to the images' original size + output_shape = (logits.shape[0], labels.shape[1], labels.shape[2], logits.shape[3]) + upsampled_logits = self.upsample(logits, output_shape) + + if auxiliary_logits is not None: + upsampled_auxiliary_logits = self.upsample(auxiliary_logits, output_shape) + # compute weighted loss + # Copied from: https://flax.readthedocs.io/en/latest/notebooks/annotated_mnist.html + labels_onehot = jax.nn.one_hot(labels, num_classes=self.config.num_labels) + main_loss = optax.softmax_cross_entropy(logits=upsampled_logits, labels=labels_onehot).mean() + auxiliary_loss = optax.softmax_cross_entropy( + logits=upsampled_auxiliary_logits, labels=labels_onehot + ).mean() + loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return FlaxSemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DPT Model with a semantic segmentation head on top e.g. for ADE20k, CityScapes. 
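+    An auxiliary head on top of the last fused feature map is added as well when `config.use_auxiliary_head` is set.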
+ """, + DPT_START_DOCSTRING, +) +class FlaxDPTForSemanticSegmentation(FlaxDPTPreTrainedModel): + module_class = FlaxDPTForSemanticSegmentationModule diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 44c3b1cf3e4b..07bcca8a3956 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -592,6 +592,34 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxDPTForDepthEstimation(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxDPTForSemanticSegmentation(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxDPTModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxDPTPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxElectraForCausalLM(metaclass=DummyObject): _backends = ["flax"] diff --git a/tests/models/dpt/test_modeling_flax_dpt.py b/tests/models/dpt/test_modeling_flax_dpt.py new file mode 100644 index 000000000000..46c8d8f50a7f --- /dev/null +++ b/tests/models/dpt/test_modeling_flax_dpt.py @@ -0,0 +1,187 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the Flax DPT model. 
""" + +import inspect +import unittest + +import numpy as np + +from transformers import DPTConfig, is_flax_available +from transformers.testing_utils import require_flax, slow + +from ...test_configuration_common import ConfigTester +from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor + + +if is_flax_available(): + + import jax + from transformers.models.dpt.modeling_flax_dpt import ( + FlaxDPTForDepthEstimation, + FlaxDPTForSemanticSegmentation, + FlaxDPTModel, + ) + + +class FlaxDPTModelTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=2, + image_size=32, + patch_size=16, + num_channels=3, + is_training=True, + use_labels=False, + hidden_size=32, + num_hidden_layers=4, + backbone_out_indices=[0, 1, 2, 3], + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + initializer_range=0.02, + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.backbone_out_indices = backbone_out_indices + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.num_labels = num_labels + self.scope = scope + # sequence length of DPT = num_patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels) + + config = DPTConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + backbone_out_indices=self.backbone_out_indices, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + return config, pixel_values, labels + + def create_and_check_model(self, config, pixel_values, labels): + + model = FlaxDPTModel(config=config) + result = model(pixel_values) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + image_size = (self.image_size, self.image_size) + patch_size = (self.patch_size, self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + pixel_values, + labels, + ) = config_and_inputs + inputs_dict = {"pixel_values": pixel_values, "labels": labels} + return config, inputs_dict + + +@require_flax +class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase): + + 
all_model_classes = ( + (FlaxDPTModel, FlaxDPTForSemanticSegmentation, FlaxDPTForDepthEstimation) if is_flax_available() else () + ) + + def setUp(self) -> None: + self.model_tester = FlaxDPTModelTester(self) + self.config_tester = ConfigTester(self, config_class=DPTConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + # We neeed to override this test because ViT's forward signature is different than text models. + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.__call__) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + # We need to override this test because ViT expects pixel_values instead of input_ids + def test_jit_compilation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def model_jitted(pixel_values, **kwargs): + return model(pixel_values=pixel_values, **kwargs) + + with self.subTest("JIT Enabled"): + jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + model = model_class_name.from_pretrained("Intel/dpt-large", from_pt=True) + outputs = model(np.ones((1, 3, 384, 384))) + self.assertIsNotNone(outputs) + + # TODO: add tests for segmentation and depth estimation (logits) diff --git a/utils/check_repo.py b/utils/check_repo.py index c3060b048aef..3f69b7dfd88b 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -147,6 +147,8 @@ "FlaxCLIPTextModel", "FlaxCLIPVisionModel", "FlaxWav2Vec2ForCTC", + "FlaxDPTForSemanticSegmentation", + "FlaxDPTForDepthEstimation", "DetrForSegmentation", "DPRReader", "FlaubertForQuestionAnswering", From 563217b67a59730c578339c7f3602983c3783699 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Sat, 18 Jun 2022 22:46:19 +0200 Subject: [PATCH 02/22] make quality --- .../models/dpt/modeling_flax_dpt.py | 17 ----------------- tests/models/dpt/test_modeling_flax_dpt.py | 2 +- 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index 83ac0257c641..20140d2d9d79 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -433,23 +433,6 @@ def __call__(self, hidden_state): return hidden_state + residual - def __call__(self, hidden_state): - residual = hidden_state - hidden_state = self.activation1(hidden_state) - - hidden_state = self.convolution1(hidden_state) - - if self.use_batch_norm: - hidden_state = self.batch_norm1(hidden_state) - - hidden_state = self.activation2(hidden_state) - hidden_state = self.convolution2(hidden_state) 
- - if self.use_batch_norm: - hidden_state = self.batch_norm2(hidden_state) - - return hidden_state + residual - class FlaxDPTFeatureFusionLayer(nn.Module): config: DPTConfig diff --git a/tests/models/dpt/test_modeling_flax_dpt.py b/tests/models/dpt/test_modeling_flax_dpt.py index 46c8d8f50a7f..aff61075ae6a 100644 --- a/tests/models/dpt/test_modeling_flax_dpt.py +++ b/tests/models/dpt/test_modeling_flax_dpt.py @@ -127,7 +127,7 @@ def prepare_config_and_inputs_for_common(self): @require_flax -class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase): +class FlaxDPTModelTest(FlaxModelTesterMixin, unittest.TestCase): all_model_classes = ( (FlaxDPTModel, FlaxDPTForSemanticSegmentation, FlaxDPTForDepthEstimation) if is_flax_available() else () From 77220f74882cf3af9d4bdfd97c6bb0bb5c4842d8 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 19 Jun 2022 00:12:17 +0200 Subject: [PATCH 03/22] implement pt_flax equivlence test to bypass `AttributeError: 'NoneType' object has no attribute 'tolist'` --- tests/models/dpt/test_modeling_flax_dpt.py | 138 ++++++++++++++++++++- 1 file changed, 137 insertions(+), 1 deletion(-) diff --git a/tests/models/dpt/test_modeling_flax_dpt.py b/tests/models/dpt/test_modeling_flax_dpt.py index aff61075ae6a..4a868d2180a4 100644 --- a/tests/models/dpt/test_modeling_flax_dpt.py +++ b/tests/models/dpt/test_modeling_flax_dpt.py @@ -20,7 +20,7 @@ import numpy as np from transformers import DPTConfig, is_flax_available -from transformers.testing_utils import require_flax, slow +from transformers.testing_utils import require_flax, slow, is_pt_flax_cross_test from ...test_configuration_common import ConfigTester from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor @@ -84,6 +84,7 @@ def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) labels = None + if self.use_labels: labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels) @@ -140,6 +141,141 @@ def setUp(self) -> None: def test_config(self): self.config_tester.run_common_tests() + + @is_pt_flax_cross_test + def test_equivalence_pt_to_flax(self): + import torch + import transformers + import tempfile + import jax.numpy as jnp + from transformers.testing_utils import torch_device + from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax + # It might be better to put this inside the for loop below (because we modify the config there). + # But logically, it is fine. + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + prepared_inputs_dict.pop('labels') + pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class + pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + pt_model = pt_model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
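+                # Once that is done, the PyTorch state dict is converted into the Flax parameter tree
+                # (`convert_pytorch_state_dict_to_flax`) so that both models run with identical weights.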
+ pt_model.config.use_cache = False + fx_model = model_class(config, dtype=jnp.float32) + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) + + + @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + import torch + import transformers + import tempfile + import jax.numpy as jnp + + from transformers.testing_utils import torch_device + + from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + prepared_inputs_dict.pop('labels') + pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class + pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + pt_model = pt_model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
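+                # Here the direction is reversed: the randomly initialised Flax parameters are loaded
+                # back into the PyTorch model (`load_flax_weights_in_pytorch_model`) before the outputs are compared.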
+ pt_model.config.use_cache = False + fx_model = model_class(config, dtype=jnp.float32) + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) + + # send pytorch model to the correct device + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys) + + # We neeed to override this test because ViT's forward signature is different than text models. def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() From 5f1341ed5dd81db54ce8dd614efc5cb04bfd44ba Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Sun, 19 Jun 2022 15:42:43 +0200 Subject: [PATCH 04/22] add few changes: - BN seems to work now - Equivalency test pass with tol=1e-4 but only with a hack --- .../models/dpt/modeling_flax_dpt.py | 214 +++++++++++++----- tests/models/dpt/test_modeling_flax_dpt.py | 30 +-- 2 files changed, 170 insertions(+), 74 deletions(-) diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index 20140d2d9d79..8dc8ae1222c2 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -282,7 +282,7 @@ class FlaxDPTReassembleLayer(nn.Module): def setup(self): # projection self.projection = nn.Conv( - self.config.hidden_size, + self.channels, kernel_size=(1, 1), dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), @@ -310,22 +310,48 @@ def __call__(self, hidden_state): return hidden_state -class FlaxDPTReassembleStage(nn.Module): +class FlaxDPTReassembleLayerCollection(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 def setup(self): - self.layers = [ - FlaxDPTReassembleLayer(self.config, factor=factor, channels=self.config.neck_hidden_sizes[i]) + FlaxDPTReassembleLayer(self.config, factor=factor, channels=self.config.neck_hidden_sizes[i], name=str(i)) for i, factor in zip(range(len(self.config.neck_hidden_sizes)), self.config.reassemble_factors) ] + def __call__(self, x, i): + return self.layers[i](x) + + +class FlaxDPTReadoutProjectCollectionLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = [ + nn.Sequential([nn.Dense(self.config.hidden_size, name=str(i)), ACT2FN[self.config.hidden_act]]) + for i in range(len(self.config.neck_hidden_sizes)) + ] + + def __call__(self, hidden_states, i): + return self.layers[i](hidden_states) + + +class FlaxDPTReassembleStage(nn.Module): + 
config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + + self.layers = FlaxDPTReassembleLayerCollection(self.config, self.dtype) + if self.config.readout_type == "project": - self.readout_projects = [ - nn.Sequential([nn.Dense(self.config.hidden_size), ACT2FN[self.config.hidden_act]]) - for _ in range(len(self.config.neck_hidden_sizes)) - ] + # self.readout_projects = [ + # nn.Sequential([nn.Dense(self.config.hidden_size), ACT2FN[self.config.hidden_act]]) + # for i in range(len(self.config.neck_hidden_sizes)) + # ] + self.readout_projects = FlaxDPTReadoutProjectCollectionLayer(self.config, self.dtype) def __call__(self, hidden_states): """ @@ -349,7 +375,8 @@ def __call__(self, hidden_states): readout = jnp.expand_dims(cls_token, axis=1) readout = jnp.repeat(readout, size * size, axis=1) # concatenate the readout token to the hidden states and project - hidden_state = self.readout_projects[i](jnp.concatenate((hidden_state, readout), axis=-1)) + # hidden_state = self.readout_projects[i](: + hidden_state = self.readout_projects(jnp.concatenate((hidden_state, readout), axis=-1), i) # reshape back to (B, C, H, W) hidden_state = jnp.reshape(hidden_state, feature_shape) elif self.config.readout_type == "add": @@ -357,7 +384,8 @@ def __call__(self, hidden_states): cls_token, axis=-1 ) hidden_state = jnp.reshape(hidden_state, feature_shape) - hidden_state = self.layers[i](hidden_state) + # hidden_state = self.layers[i](hidden_state) + hidden_state = self.layers(hidden_state, i) out.append(hidden_state) return out @@ -369,7 +397,8 @@ class FlaxDPTFeatureFusionStage(nn.Module): def setup(self): super().__init__() - self.layers = [FlaxDPTFeatureFusionLayer(self.config) for _ in range(len(self.config.neck_hidden_sizes))] + # self.layers = [FlaxDPTFeatureFusionLayer(self.config) for i in range(len(self.config.neck_hidden_sizes))] + self.layers = FlaxDPTFeatureFusionLayerCollection(self.config, self.dtype) def __call__(self, hidden_states): # reversing the hidden_states, we start from the last @@ -377,11 +406,15 @@ def __call__(self, hidden_states): fused_hidden_states = [] # first layer only uses the last hidden_state - fused_hidden_state = self.layers[0](hidden_states[0]) + fused_hidden_state = self.layers(hidden_states[0], residual=None, i=0) fused_hidden_states.append(fused_hidden_state) # looping from the last layer to the second - for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]): - fused_hidden_state = layer(fused_hidden_state, hidden_state) + # for hidden_state, layer in zip(hidden_states[1:], self.layers.layers[1:]): + # fused_hidden_state = layer(fused_hidden_state, hidden_state) + # fused_hidden_states.append(fused_hidden_state) + + for i, hidden_state in enumerate(hidden_states[1:]): + fused_hidden_state = self.layers(fused_hidden_state, residual=hidden_state, i=i + 1) fused_hidden_states.append(fused_hidden_state) return fused_hidden_states @@ -413,8 +446,8 @@ def setup(self): ) if self.use_batch_norm: - self.batch_norm1 = nn.BatchNorm(use_running_average=False) - self.batch_norm2 = nn.BatchNorm(use_running_average=False) + self.batch_norm1 = nn.BatchNorm(use_running_average=True) + self.batch_norm2 = nn.BatchNorm(use_running_average=True) def __call__(self, hidden_state): residual = hidden_state @@ -434,6 +467,19 @@ def __call__(self, hidden_state): return hidden_state + residual +class FlaxDPTFeatureFusionLayerCollection(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = [ + 
FlaxDPTFeatureFusionLayer(self.config, name=str(i)) for i in range(len(self.config.neck_hidden_sizes)) + ] + + def __call__(self, hidden_states, residual=None, i=0): + return self.layers[i](hidden_states, residual) + + class FlaxDPTFeatureFusionLayer(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 @@ -685,7 +731,8 @@ def __call__( return_dict, labels, rngs=rngs, - ) + mutable=["batch_stats"], + )[0] class FlaxDPTModule(nn.Module): @@ -746,17 +793,28 @@ class FlaxDPTModel(FlaxDPTPreTrainedModel): module_class = FlaxDPTModule -class FlaxDPTNeck(nn.Module): +class FlaxDPTConvCollection(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 def setup(self): - # postprocessing - self.reassemble_stage = FlaxDPTReassembleStage(self.config) - self.conv_list = [ - nn.Conv(self.config.fusion_hidden_size, kernel_size=(3, 3), padding=1, use_bias=False) + self.convs = [ + nn.Conv(self.config.fusion_hidden_size, kernel_size=(3, 3), padding=1, use_bias=False, name=str(i)) for i in range(len(self.config.neck_hidden_sizes)) ] + + def __call__(self, features): + return [self.convs[i](feature) for i, feature in enumerate(features)] + + +class FlaxDPTNeck(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # postprocessing + self.reassemble_stage = FlaxDPTReassembleStage(self.config, self.dtype) + self.convs = FlaxDPTConvCollection(self.config, self.dtype) # fusion self.fusion_stage = FlaxDPTFeatureFusionStage(self.config) @@ -770,7 +828,8 @@ def __call__(self, hidden_states): # postprocess hidden states features = self.reassemble_stage(hidden_states) - features = [self.conv_list[i](feature) for i, feature in enumerate(features)] + # features = [self.convs[i](feature) for i, feature in enumerate(features)] + features = self.convs(features) # fusion blocks output = self.fusion_stage(features) @@ -792,24 +851,39 @@ def __call__(self, x, output_size=None): return jax.image.resize(x, output_size, method="bilinear") -class FlaxDPTDepthEstimationHead(nn.Module): +class FlaxDPTDepthEstimationHeadCollectionLayer(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 def setup(self): - - features = self.config.fusion_hidden_size - self.head = nn.Sequential( - [ - nn.Conv(features // 2, kernel_size=(3, 3), strides=(1, 1), padding=1), - FlaxDPTUpsample(scale=2, method="bilinear"), - nn.Conv(32, kernel_size=(3, 3), strides=(1, 1), padding=1), - ACT2FN["relu"], - nn.Conv(1, kernel_size=(1, 1), strides=(1, 1), padding=0), - ACT2FN["relu"], - ] + self.conv1 = nn.Conv( + self.config.fusion_hidden_size // 2, kernel_size=(3, 3), strides=(1, 1), padding=1, name="0" ) + self.upsample = FlaxDPTUpsample(scale=2, method="bilinear") + + self.conv2 = nn.Conv(32, kernel_size=(3, 3), strides=(1, 1), padding=1, name="2") + + self.act = ACT2FN["relu"] + + self.conv3 = nn.Conv(1, kernel_size=(1, 1), strides=(1, 1), padding=0, name="4") + + def __call__(self, hidden_state): + x = self.conv1(hidden_state) + x = self.upsample(x) + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + return self.act(x) + + +class FlaxDPTDepthEstimationHead(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.head = FlaxDPTDepthEstimationHeadCollectionLayer(self.config, self.dtype) + def __call__(self, hidden_states): # use last features hidden_states = hidden_states[self.config.head_in_index] @@ -943,54 +1017,76 @@ class FlaxDPTForDepthEstimation(FlaxDPTPreTrainedModel): module_class = FlaxDPTForDepthEstimationModule +class 
FlaxDPTSemanticSegmentationHeadCollectionLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv1 = nn.Conv(self.config.fusion_hidden_size, kernel_size=(3, 3), padding=1, name="0", use_bias=False) + self.bn = nn.BatchNorm(use_running_average=True, name="1") + self.dropout = nn.Dropout(0.1, deterministic=True) + self.act = ACT2FN["relu"] + self.conv2 = nn.Conv(self.config.num_labels, kernel_size=(1, 1), name="4") + self.upsample = FlaxDPTUpsample(scale=2, method="bilinear") + + def __call__(self, hidden_states): + x = self.conv1(hidden_states) + x = self.bn(x) + x = self.act(x) + x = self.dropout(x) + x = self.conv2(x) + x = self.upsample(x) + return x + + class FlaxDPTSemanticSegmentationHead(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # TODO: Change this to make sure BatchNorm + Dropout work / Put them outside a Sequential Module def setup(self): - features = self.config.fusion_hidden_size - self.head = nn.Sequential( - [ - nn.Conv(features, kernel_size=(3, 3), padding=1), - # nn.BatchNorm(use_running_average=False), - ACT2FN["relu"], - # nn.Dropout(self.config.semantic_classifier_dropout, deterministic=False), - nn.Conv(self.config.num_labels, kernel_size=(1, 1)), - FlaxDPTUpsample(scale=2, method="bilinear"), - ] - ) + self.head = FlaxDPTSemanticSegmentationHeadCollectionLayer(self.config, self.dtype) + # @nn.compact def __call__(self, hidden_states): # use last features hidden_states = hidden_states[self.config.head_in_index] logits = self.head(hidden_states) + return jnp.transpose(logits, (0, 3, 1, 2)) - return logits + +class FlaxDPTAuxiliaryHeadCollectionLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv1 = nn.Conv(self.config.fusion_hidden_size, kernel_size=(3, 3), padding=1, name="0", use_bias=False) + self.bn = nn.BatchNorm(use_running_average=True, name="1") + self.act = ACT2FN["relu"] + self.dropout = nn.Dropout(0.1, deterministic=True) + self.conv2 = nn.Conv(self.config.num_labels, kernel_size=(1, 1), name="4") + + def __call__(self, hidden_states): + x = self.conv1(hidden_states) + x = self.bn(x) + x = self.act(x) + x = self.dropout(x) + x = self.conv2(x) + return x class FlaxDPTAuxiliaryHead(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 - # TODO: Change this to make sure BatchNorm + Dropout work / Put them outside a Sequential Module def setup(self): - features = self.config.fusion_hidden_size - self.head = nn.Sequential( - [ - nn.Conv(features, kernel_size=(3, 3), padding=1, use_bias=False), # bias=False - # nn.BatchNorm(use_running_average=False), - ACT2FN["relu"], - # nn.Dropout(0.1, deterministic=False), - nn.Conv(self.config.num_labels, kernel_size=(1, 1)), - ] - ) + self.head = FlaxDPTAuxiliaryHeadCollectionLayer(self.config, self.dtype) def __call__(self, hidden_states): logits = self.head(hidden_states) - return logits + return jnp.transpose(logits, (0, 3, 1, 2)) class FlaxDPTForSemanticSegmentationModule(nn.Module): diff --git a/tests/models/dpt/test_modeling_flax_dpt.py b/tests/models/dpt/test_modeling_flax_dpt.py index 4a868d2180a4..1c6a895a6819 100644 --- a/tests/models/dpt/test_modeling_flax_dpt.py +++ b/tests/models/dpt/test_modeling_flax_dpt.py @@ -20,7 +20,7 @@ import numpy as np from transformers import DPTConfig, is_flax_available -from transformers.testing_utils import require_flax, slow, is_pt_flax_cross_test +from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow from ...test_configuration_common 
import ConfigTester from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor @@ -141,15 +141,17 @@ def setUp(self) -> None: def test_config(self): self.config_tester.run_common_tests() - @is_pt_flax_cross_test def test_equivalence_pt_to_flax(self): + import tempfile + import torch + + import jax.numpy as jnp import transformers - import tempfile - import jax.numpy as jnp - from transformers.testing_utils import torch_device from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax + from transformers.testing_utils import torch_device + # It might be better to put this inside the for loop below (because we modify the config there). # But logically, it is fine. config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -163,7 +165,7 @@ def test_equivalence_pt_to_flax(self): # prepare inputs prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - prepared_inputs_dict.pop('labels') + prepared_inputs_dict.pop("labels") pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} # load corresponding PyTorch class @@ -204,18 +206,17 @@ def test_equivalence_pt_to_flax(self): self.assertEqual(fx_keys, pt_keys) self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) - @is_pt_flax_cross_test def test_equivalence_flax_to_pt(self): + import tempfile + import torch + + import jax.numpy as jnp import transformers - import tempfile - import jax.numpy as jnp - + from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model from transformers.testing_utils import torch_device - from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: @@ -227,9 +228,9 @@ def test_equivalence_flax_to_pt(self): # prepare inputs prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - prepared_inputs_dict.pop('labels') + prepared_inputs_dict.pop("labels") pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} - + # load corresponding PyTorch class pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning pt_model_class = getattr(transformers, pt_model_class_name) @@ -275,7 +276,6 @@ def test_equivalence_flax_to_pt(self): self.assertEqual(fx_keys, pt_keys) self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys) - # We neeed to override this test because ViT's forward signature is different than text models. 
def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() From 86320ad0aa929bb491ce2e120177189ffd4f9964 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 20 Jun 2022 11:58:21 +0200 Subject: [PATCH 05/22] add modification sequential --- src/transformers/models/dpt/modeling_flax_dpt.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index 8dc8ae1222c2..95cf697a0dd0 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -324,13 +324,27 @@ def __call__(self, x, i): return self.layers[i](x) +class FlaxDPTReadoutProjectSequentialCollectionLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense(self.config.hidden_size, name="0") + self.act = ACT2FN[self.config.hidden_act] + + def __call__(self, x): + x = self.dense(x) + x = self.act(x) + return x + + class FlaxDPTReadoutProjectCollectionLayer(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.layers = [ - nn.Sequential([nn.Dense(self.config.hidden_size, name=str(i)), ACT2FN[self.config.hidden_act]]) + FlaxDPTReadoutProjectSequentialCollectionLayer(self.config, self.dtype, name=str(i)) for i in range(len(self.config.neck_hidden_sizes)) ] From 6b2c28c81b8424ca8435425f6016670c3eafb884 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 20 Jun 2022 13:43:11 +0200 Subject: [PATCH 06/22] fix copies --- .../models/dpt/modeling_flax_dpt.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index 95cf697a0dd0..57c85b8c283e 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -42,7 +42,7 @@ DPT_START_DOCSTRING = r""" This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading, saving and converting weights from PyTorch models) This + library implements for all its model (such as downloading, saving and converting weights from Flax models) This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and behavior. Finally, this model supports inherent JAX features such as: @@ -51,7 +51,7 @@ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: - config ([`ViTConfig`]): Model configuration class with all the parameters of the model. + config ([`DPTConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. 
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): @@ -79,6 +79,7 @@ """ +# Copied from transformers.models.vit.modeling_flax_vit.FlaxPatchEmbeddings with ViT->DPT class FlaxPatchEmbeddings(nn.Module): config: DPTConfig @@ -104,6 +105,7 @@ def __call__(self, pixel_values): return jnp.reshape(x, (batch_size, -1, channels)) +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTEmbeddings with ViT->DPT class FlaxDPTEmbeddings(nn.Module): """Construct the CLS token, position and patch embeddings.""" @@ -131,7 +133,8 @@ def __call__(self, pixel_values, deterministic=True): return embeddings -class FlaxViTSelfAttention(nn.Module): +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfAttention with ViT->DPT +class FlaxDPTSelfAttention(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -215,7 +218,8 @@ def __call__(self, hidden_states, attention_output, deterministic: bool = True): return hidden_states -class FlaxDPTViTSelfOutput(nn.Module): +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfOutput with ViT->DPT +class FlaxDPTSelfOutput(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -233,13 +237,14 @@ def __call__(self, hidden_states, input_tensor, deterministic: bool = True): return hidden_states +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTAttention with ViT->DPT class FlaxDPTViTAttention(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 def setup(self): - self.attention = FlaxViTSelfAttention(self.config, dtype=self.dtype) - self.output = FlaxDPTViTSelfOutput(self.config, dtype=self.dtype) + self.attention = FlaxDPTSelfAttention(self.config, dtype=self.dtype) + self.output = FlaxDPTSelfOutput(self.config, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False): attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions) @@ -254,6 +259,7 @@ def __call__(self, hidden_states, deterministic=True, output_attentions: bool = return outputs +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTIntermediate with ViT->DPT class FlaxDPTViTIntermediate(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation From 926fb778d7f3b856647e4ba9544d77e6a756368f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 20 Jun 2022 23:14:16 +0200 Subject: [PATCH 07/22] few fixes - more documentation - fix nit --- src/transformers/modeling_flax_outputs.py | 34 +++++++++---------- .../models/dpt/modeling_flax_dpt.py | 22 ++++++------ 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/transformers/modeling_flax_outputs.py b/src/transformers/modeling_flax_outputs.py index 51fbe101b36c..cb93bd319536 100644 --- a/src/transformers/modeling_flax_outputs.py +++ b/src/transformers/modeling_flax_outputs.py @@ -648,18 +648,18 @@ class FlaxDepthEstimatorOutput(ModelOutput): Base class for outputs of depth estimation models. Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`): + loss (`jnp.ndarray` of shape `(1,)`, *optional*, returned when `labels` is provided): + Depth Estimation loss. + predicted_depth (`jnp.ndarray` of shape `(batch_size, height, width)`): Predicted depth for each pixel. 
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention @@ -678,10 +678,10 @@ class FlaxSemanticSegmenterOutput(ModelOutput): Base class for outputs of depth estimation models. Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): - Classification scores for each pixel. + loss (`jnp.ndarray` of shape `(1,)`, *optional*, returned when `labels` is provided): + Semantic Segmentation loss (CrossEntropy Loss). + logits (`jnp.ndarray` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Semantic Segmentation raw logits for each pixel. @@ -691,13 +691,13 @@ class FlaxSemanticSegmenterOutput(ModelOutput): - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. 
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index 57c85b8c283e..a82d1cf78504 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -868,7 +868,7 @@ def __call__(self, x, output_size=None): if output_size is None: output_size = x.shape output_size = (output_size[0], output_size[1] * self.scale, output_size[2] * self.scale, output_size[3]) - return jax.image.resize(x, output_size, method="bilinear") + return jax.image.resize(x, output_size, method=self.method) class FlaxDPTDepthEstimationHeadCollectionLayer(nn.Module): @@ -1044,16 +1044,16 @@ class FlaxDPTSemanticSegmentationHeadCollectionLayer(nn.Module): def setup(self): self.conv1 = nn.Conv(self.config.fusion_hidden_size, kernel_size=(3, 3), padding=1, name="0", use_bias=False) self.bn = nn.BatchNorm(use_running_average=True, name="1") - self.dropout = nn.Dropout(0.1, deterministic=True) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.act = ACT2FN["relu"] self.conv2 = nn.Conv(self.config.num_labels, kernel_size=(1, 1), name="4") self.upsample = FlaxDPTUpsample(scale=2, method="bilinear") - def __call__(self, hidden_states): + def __call__(self, hidden_states, deterministic=True): x = self.conv1(hidden_states) x = self.bn(x) x = self.act(x) - x = self.dropout(x) + x = self.dropout(x, deterministic=deterministic) x = self.conv2(x) x = self.upsample(x) return x @@ -1068,11 +1068,11 @@ def setup(self): self.head = FlaxDPTSemanticSegmentationHeadCollectionLayer(self.config, self.dtype) # @nn.compact - def __call__(self, hidden_states): + def __call__(self, hidden_states, deterministic=True): # use last features hidden_states = hidden_states[self.config.head_in_index] - logits = self.head(hidden_states) + logits = self.head(hidden_states, deterministic=deterministic) return jnp.transpose(logits, (0, 3, 1, 2)) @@ -1084,14 +1084,14 @@ def setup(self): self.conv1 = nn.Conv(self.config.fusion_hidden_size, kernel_size=(3, 3), padding=1, name="0", use_bias=False) self.bn = nn.BatchNorm(use_running_average=True, name="1") self.act = ACT2FN["relu"] - self.dropout = nn.Dropout(0.1, deterministic=True) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.conv2 = nn.Conv(self.config.num_labels, kernel_size=(1, 1), name="4") - def __call__(self, hidden_states): + def __call__(self, hidden_states, deterministic=True): x = self.conv1(hidden_states) x = self.bn(x) x = self.act(x) - x = self.dropout(x) + x = self.dropout(x, deterministic=deterministic) x = self.conv2(x) return x @@ -1103,8 +1103,8 @@ class FlaxDPTAuxiliaryHead(nn.Module): def setup(self): self.head = FlaxDPTAuxiliaryHeadCollectionLayer(self.config, self.dtype) - def __call__(self, hidden_states): - logits = self.head(hidden_states) + def __call__(self, hidden_states, deterministic=True): + logits = self.head(hidden_states, deterministic=deterministic) return jnp.transpose(logits, (0, 3, 1, 2)) From 72bf6a427a64089d2a2cad565dc23b6ff83178f1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Sat, 25 Jun 2022 20:10:32 +0200 Subject: [PATCH 08/22] changes - add custom conv transpose2d function - modify test --- .../models/dpt/gradient_convolution.py | 460 ++++++++++++++++++ .../models/dpt/modeling_flax_dpt.py | 45 +- tests/models/dpt/test_modeling_flax_dpt.py | 145 +----- 3 files changed, 487 insertions(+), 163 
deletions(-) create mode 100644 src/transformers/models/dpt/gradient_convolution.py diff --git a/src/transformers/models/dpt/gradient_convolution.py b/src/transformers/models/dpt/gradient_convolution.py new file mode 100644 index 000000000000..d67fb9b6537c --- /dev/null +++ b/src/transformers/models/dpt/gradient_convolution.py @@ -0,0 +1,460 @@ +from typing import (Any, Callable, Optional, Sequence, Tuple, + Union) +import flax.linen as nn +from jax import lax +from typing import (Any, Callable, NamedTuple, Optional, Sequence, Tuple, + Union) +from jax.lax import conv_general_dilated +from flax.linen.initializers import lecun_normal +from flax.linen.initializers import zeros +from flax.linen.module import compact +from jax import lax +import jax.numpy as jnp +import numpy as np + +default_kernel_init = lecun_normal() + +PRNGKey = Any +Shape = Tuple[int, ...] +Dtype = Any # this could be a real type? +Array = Any +PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]] + +PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]] +PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]] +LaxPadding = Union[str, Sequence[Tuple[int, int]]] + + +class ConvDimensionNumbers(NamedTuple): + """ + Describes batch, spatial, and feature dimensions of a convolution. + + Args: + lhs_spec: a tuple of nonnegative integer dimension numbers containing + `(batch dimension, feature dimension, spatial dimensions...)`. + rhs_spec: a tuple of nonnegative integer dimension numbers containing + `(out feature dimension, in feature dimension, spatial dimensions...)`. + out_spec: a tuple of nonnegative integer dimension numbers containing + `(batch dimension, feature dimension, spatial dimensions...)`. 
+ """ + lhs_spec: Sequence[int] + rhs_spec: Sequence[int] + out_spec: Sequence[int] + +ConvGeneralDilatedDimensionNumbers = Union[ + None, ConvDimensionNumbers, Tuple[str, str, str]] + +def _flip_axes(x, axes): + """Flip ndarray 'x' along each axis specified in axes tuple.""" + for axis in axes: + x = np.flip(x, axis) + return x + +def conv_general_permutations(dimension_numbers): + """Utility for convolution dimension permutations relative to Conv HLO.""" + lhs_spec, rhs_spec, out_spec = dimension_numbers + lhs_char, rhs_char, out_char = charpairs = ("N", "C"), ("O", "I"), ("N", "C") + for i, (a, b) in enumerate(charpairs): + if not dimension_numbers[i].count(a) == dimension_numbers[i].count(b) == 1: + msg = ("convolution dimension_numbers[{}] must contain the characters " + "'{}' and '{}' exactly once, got {}.") + raise TypeError(msg.format(i, a, b, dimension_numbers[i])) + if len(dimension_numbers[i]) != len(set(dimension_numbers[i])): + msg = ("convolution dimension_numbers[{}] cannot have duplicate " + "characters, got {}.") + raise TypeError(msg.format(i, dimension_numbers[i])) + if not (set(lhs_spec) - set(lhs_char) == set(rhs_spec) - set(rhs_char) == + set(out_spec) - set(out_char)): + msg = ("convolution dimension_numbers elements must each have the same " + "set of spatial characters, got {}.") + raise TypeError(msg.format(dimension_numbers)) + + def getperm(spec, charpair): + spatial = (i for i, c in enumerate(spec) if c not in charpair) + if spec is not rhs_spec: + spatial = sorted(spatial, key=lambda i: rhs_spec.index(spec[i])) + return (spec.index(charpair[0]), spec.index(charpair[1])) + tuple(spatial) + + lhs_perm, rhs_perm, out_perm = map(getperm, dimension_numbers, charpairs) + return lhs_perm, rhs_perm, out_perm + +def conv_dimension_numbers_(lhs_shape, rhs_shape, dimension_numbers + ) -> ConvDimensionNumbers: + """Converts convolution `dimension_numbers` to a `ConvDimensionNumbers`. + + Args: + lhs_shape: tuple of nonnegative integers, shape of the convolution input. + rhs_shape: tuple of nonnegative integers, shape of the convolution kernel. + dimension_numbers: None or a tuple/list of strings or a ConvDimensionNumbers + object following the convolution dimension number specification format in xla_client.py. + + Returns: + A `ConvDimensionNumbers` object that represents `dimension_numbers` in the canonical form used by lax functions. + """ + if isinstance(dimension_numbers, ConvDimensionNumbers): + return dimension_numbers + if len(lhs_shape) != len(rhs_shape): + msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}." + raise TypeError(msg.format(len(lhs_shape), len(rhs_shape))) + + if dimension_numbers is None: + iota = tuple(range(len(lhs_shape))) + return ConvDimensionNumbers(iota, iota, iota) + elif isinstance(dimension_numbers, (list, tuple)): + if len(dimension_numbers) != 3: + msg = "convolution dimension_numbers list/tuple must be length 3, got {}." + raise TypeError(msg.format(len(dimension_numbers))) + if not all(isinstance(elt, str) for elt in dimension_numbers): + msg = "convolution dimension_numbers elements must be strings, got {}." 
+ raise TypeError(msg.format(tuple(map(type, dimension_numbers)))) + msg = ("convolution dimension_numbers[{}] must have len equal to the ndim " + "of lhs and rhs, got {} for lhs and rhs shapes {} and {}.") + for i, elt in enumerate(dimension_numbers): + if len(elt) != len(lhs_shape): + raise TypeError(msg.format(i, len(elt), lhs_shape, rhs_shape)) + + lhs_spec, rhs_spec, out_spec = conv_general_permutations(dimension_numbers) + return ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) + else: + msg = "convolution dimension_numbers must be tuple/list or None, got {}." + raise TypeError(msg.format(type(dimension_numbers))) + +def _deconv_output_length(input_length, + filter_size, + padding, + output_padding=None, + stride=0, + dilation=1): + """Determines the output length of a transposed convolution given the input length. + Arguments: + Function modified from Keras. + input_length: Integer. filter_size: Integer. padding: one of `"SAME"`, `"VALID"`, or a 2-integer tuple. + output_padding: Integer, amount of padding along the output dimension. Can + be set to `None` in which case the output length is inferred. + stride: Integer. dilation: Integer. + Returns: + The output length (integer). + """ + if input_length is None: + return None + + # Get the dilated kernel size + filter_size = filter_size + (filter_size - 1) * (dilation - 1) + + # Infer length if output padding is None, else compute the exact length + if output_padding is None: + if padding == 'VALID': + length = input_length * stride + max(filter_size - stride, 0) + elif padding == 'SAME': + length = input_length * stride + else: + length = ((input_length - 1) * stride + filter_size + - padding[0] - padding[1]) + + else: + if padding == 'SAME': + pad = filter_size // 2 + total_pad = pad * 2 + elif padding == 'VALID': + total_pad = 0 + else: + total_pad = padding[0] + padding[1] + + length = ((input_length - 1) * stride + filter_size - total_pad + + output_padding) + + return length + +def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: + if isinstance(padding, str): + return padding + if isinstance(padding, int): + return [(padding, padding)] * rank + + + if isinstance(padding, Sequence) and len(padding) == rank: + new_pad = [] + for p in padding: + if isinstance(p, int): + new_pad.append((p, p)) + elif isinstance(p, tuple) and len(p) == 2: + new_pad.append(p) + else: + break + if len(new_pad) == rank: + return new_pad + raise ValueError( + f'Invalid padding format: {padding}, should be str, int,' + f' or a sequence of len {rank} where each element is an' + f' int or pair of ints.') + +def _compute_adjusted_padding( + input_size: int, + output_size: int, + kernel_size: int, + stride: int, + padding: Union[str, Tuple[int, int]], + dilation: int = 1, +) -> Tuple[int, int]: + """Computes adjusted padding for desired ConvTranspose `output_size`. + Ported from DeepMind Haiku. 
+ """ + kernel_size = (kernel_size - 1) * dilation + 1 + if padding == "VALID": + expected_input_size = (output_size - kernel_size + stride) // stride + if input_size != expected_input_size: + raise ValueError(f"The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}.") + padding_before = 0 + elif padding == "SAME": + expected_input_size = (output_size + stride - 1) // stride + if input_size != expected_input_size: + raise ValueError(f"The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}.") + padding_needed = max(0, + (input_size - 1) * stride + kernel_size - output_size) + padding_before = padding_needed // 2 + else: + padding_before = padding[0] # type: ignore[assignment] + + expanded_input_size = (input_size - 1) * stride + 1 + padded_out_size = output_size + kernel_size - 1 + pad_before = kernel_size - 1 - padding_before + pad_after = padded_out_size - expanded_input_size - pad_before + return (pad_before, pad_after) + +def gradient_based_conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], + padding: Union[str, Sequence[Tuple[int, int]]], + output_padding: Optional[Sequence[int]] = None, + output_shape: Optional[Sequence[int]] = None, + dilation: Optional[Sequence[int]] = None, + dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, + transpose_kernel: bool = True, + precision: PrecisionLike = None) -> Array: + """Convenience wrapper for calculating the N-d transposed convolution. + Args: + Much like *conv_transpose*, this function calculates transposed convolutions via fractionally strided convolution + rather than calculating the gradient (transpose) of a forward convolution. However, the latter is more common among + deep learning frameworks, such as TensorFlow, PyTorch, and Keras. This function provides the same set of APIs to help: + reproduce results in these frameworks. + lhs: a rank *n+2* dimensional input array. rhs: a rank *n+2* dimensional array of kernel weights. strides: sequence + of *n* integers, amounts to strides of the corresponding forward convolution. padding: *"SAME"*, *"VALID"*, or a + sequence of *n* integer 2-tuples that controls + the before-and-after padding for each *n* spatial dimension of the corresponding forward convolution. + output_padding: A sequence of integers specifying the amount of padding along + each spacial dimension of the output tensor, used to disambiguate the output shape of transposed convolutions + when the stride is larger than 1. (see a detailed description at + 1https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html) The amount of output padding along a + given dimension must be lower than the stride along that same dimension. If set to *None* (default), the output + shape is inferred. If both *output_padding* and *output_shape* are specified, they have to be mutually + compatible. + output_shape: Output shape of the spatial dimensions of a transpose + convolution. Can be *None* or an iterable of *n* integers. If a *None* value is given (default), the shape is + automatically calculated. Similar to *output_padding*, *output_shape* is also for disambiguating the output shape + when stride > 1 (see also https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose) If both + *output_padding* and *output_shape* are specified, they have to be mutually compatible. 
+ dilation: *None*, or a sequence of *n* integers, giving the + dilation factor to apply in each spatial dimension of *rhs*. Dilated convolution is also known as atrous + convolution. + dimension_numbers: tuple of dimension descriptors as in + lax.conv_general_dilated. Defaults to tensorflow convention. + transpose_kernel: if *True* flips spatial axes and swaps the input/output + channel axes of the kernel. This makes the output of this function identical to the gradient-derived functions + like keras.layers.Conv2DTranspose and torch.nn.ConvTranspose2d applied to the same kernel. Although for typical + use in neural nets this is unnecessary and makes input/output channel specification confusing, you need to set + this to *True* in order to match the behavior in many deep learning frameworks, such as TensorFlow, Keras, and + PyTorch. + precision: Optional. Either `None`, which means the default precision for + the backend, a `lax.Precision` enum value (`Precision.DEFAULT`, `Precision.HIGH` or `Precision.HIGHEST`) or a + tuple of two `lax.Precision` enums indicating precision of ``lhs``` and `rhs`. + Returns: + Transposed N-d convolution. + """ + assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2 + ndims = len(lhs.shape) + one = (1,) * (ndims - 2) + # Set dimensional layout defaults if not specified. + if dimension_numbers is None: + if ndims == 2: + dimension_numbers = ('NC', 'IO', 'NC') + elif ndims == 3: + dimension_numbers = ('NHC', 'HIO', 'NHC') + elif ndims == 4: + dimension_numbers = ('NHWC', 'HWIO', 'NHWC') + elif ndims == 5: + dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC') + else: + raise ValueError('No 4+ dimensional dimension_number defaults.') + dn = conv_dimension_numbers_(lhs.shape, rhs.shape, dimension_numbers) + # dn = dimension_numbers + k_shape = np.take(rhs.shape, dn.rhs_spec) + k_sdims = k_shape[2:] # type: ignore[index] + i_shape = np.take(lhs.shape, dn.lhs_spec) + i_sdims = i_shape[2:] # type: ignore[index] + + # Calculate correct output shape given padding and strides. + if dilation is None: + dilation = (1,) * (rhs.ndim - 2) + + if output_padding is None: + output_padding = [None] * (rhs.ndim - 2) # type: ignore[list-item] + + if isinstance(padding, str): + if padding in {'SAME', 'VALID'}: + padding = [padding] * (rhs.ndim - 2) # type: ignore[list-item] + else: + raise ValueError(f"`padding` must be 'VALID' or 'SAME'. Passed: {padding}.") + + inferred_output_shape = tuple(map(_deconv_output_length, i_sdims, k_sdims, + padding, output_padding, strides, dilation)) + if output_shape is None: + output_shape = inferred_output_shape # type: ignore[assignment] + else: + if not output_shape == inferred_output_shape: + raise ValueError(f"`output_padding` and `output_shape` are not compatible." + f"Inferred output shape from `output_padding`: {inferred_output_shape}, " + f"but got `output_shape` {output_shape}") + + pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, + k_sdims, strides, padding, dilation)) + + if transpose_kernel: + # flip spatial dims and swap input / output channel axes + rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) + rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) + return conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dimension_numbers, + precision=precision) + + +class ConvTransposeGradient(nn.Module): + """Convolution Module wrapping lax.conv_transpose. + + Attributes: + features: number of convolution filters. + kernel_size: shape of the convolutional kernel. 
For 1D convolution, + the kernel size can be passed as an integer. For all other cases, it must + be a sequence of integers. + strides: a sequence of `n` integers, representing the inter-window strides. + padding: either the string `'SAME'`, the string `'VALID'`, the string + `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. A single int is interpeted as applying the same padding + in all dims and passign a single int in a sequence causes the same padding + to be used on both sides. + kernel_dilation: `None`, or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel. Convolution with kernel dilation is also known as 'atrous + convolution'. + use_bias: whether to add a bias to the output (default: True). + dtype: the dtype of the computation (default: float32). + param_dtype: the dtype passed to parameter initializers (default: float32). + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the convolutional kernel. + bias_init: initializer for the bias. + """ + features: int + kernel_size: Union[int, Tuple[int, ...]] + strides: Optional[Tuple[int, ...]] = None + padding: PaddingLike = 'SAME' + kernel_dilation: Optional[Sequence[int]] = None + use_bias: bool = True + dtype: Dtype = jnp.float32 + param_dtype: Dtype = jnp.float32 + precision: PrecisionLike = None + kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + + @compact + def __call__(self, inputs: Array) -> Array: + """Applies a transposed convolution to the inputs. + + Behaviour mirrors of `jax.lax.conv_transpose`. + + Args: + inputs: input data with dimensions (batch, spatial_dims..., features). + This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: + this is different from the input convention used by `lax.conv_general_dilated`, which puts the spatial + dimensions last. + + Returns: + The convolved data. + """ + inputs = jnp.asarray(inputs, self.dtype) + + kernel_size: Tuple[int, ...] + if isinstance(self.kernel_size, int): + kernel_size = (self.kernel_size,) + else: + kernel_size = self.kernel_size + + is_single_input = False + if inputs.ndim == len(kernel_size) + 1: + is_single_input = True + inputs = jnp.expand_dims(inputs, axis=0) + + strides: Tuple[int, ...] + strides = self.strides or (1,) * (inputs.ndim - 2) + + in_features = inputs.shape[-1] + kernel_shape = kernel_size + (in_features, self.features) + kernel = self.param('kernel', self.kernel_init, kernel_shape, + self.param_dtype) + kernel = jnp.asarray(kernel, self.dtype) + + padding_lax = canonicalize_padding(self.padding, len(kernel_size)) + if padding_lax == 'CIRCULAR': + padding_lax = 'VALID' + + y = gradient_based_conv_transpose( + inputs, + kernel, + strides, + padding_lax, + dilation=self.kernel_dilation, + precision=self.precision) + + if self.padding == 'CIRCULAR': + # For circular padding, we need to identify the size of the final output + # ("period") along each spatial dimension, pad each dimension to an + # integer number of periods, and wrap the array periodically around each + # dimension. 
Padding should be done in such a way that the start of the + # original input data inside the padded array is located at integer + # number of periods - otherwise the result would be circularly shifted. + + # Compute period along each spatial dimension - it's input size scaled + # by the stride. + scaled_x_dims = [ + x_dim * stride for x_dim, stride in zip(inputs.shape[1:-1], strides) + ] + # Compute difference between the current size of y and the final output + # size, and complement this difference to 2 * period - that gives how + # much we need to pad. + size_diffs = [ + -(y_dim - x_dim) % (2 * x_dim) + for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims) + ] + # Divide the padding equaly between left and right. The choice to put + # "+1" on the left (and not on the right) represents a convention for + # aligning even-sized kernels. + total_pad = [ + ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs + ] + y = np.pad(y, [(0, 0)] + total_pad + [(0, 0)]) + # Wrap the result periodically around each spatial dimension, + # one by one. + for i in range(1, y.ndim - 1): + y = y.reshape(y.shape[:i] + (-1, scaled_x_dims[i - 1]) + + y.shape[i + 1:]) + y = y.sum(axis=i) + + if is_single_input: + y = jnp.squeeze(y, axis=0) + if self.use_bias: + bias = self.param('bias', self.bias_init, (self.features,), + self.param_dtype) + bias = jnp.asarray(bias, self.dtype) + y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) + return y \ No newline at end of file diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index a82d1cf78504..78ac6aa2ab00 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -38,6 +38,7 @@ from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward from .configuration_dpt import DPTConfig +from .gradient_convolution import ConvTransposeGradient DPT_START_DOCSTRING = r""" @@ -292,12 +293,16 @@ def setup(self): kernel_size=(1, 1), dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + use_bias=True, ) # up/down sampling depending on factor if self.factor > 1: - self.resize = nn.ConvTranspose( - self.channels, kernel_size=(self.factor, self.factor), strides=(self.factor, self.factor) + self.resize = ConvTransposeGradient( + self.channels, + kernel_size=(self.factor, self.factor), + strides=(self.factor, self.factor), + use_bias=True, ) elif self.factor < 1: # so should downsample @@ -307,6 +312,7 @@ def setup(self): strides=(int(1 / self.factor), int(1 / self.factor)), dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + padding=(1, 1), ) def __call__(self, hidden_state): @@ -340,8 +346,7 @@ def setup(self): def __call__(self, x): x = self.dense(x) - x = self.act(x) - return x + return self.act(x) class FlaxDPTReadoutProjectCollectionLayer(nn.Module): @@ -367,10 +372,6 @@ def setup(self): self.layers = FlaxDPTReassembleLayerCollection(self.config, self.dtype) if self.config.readout_type == "project": - # self.readout_projects = [ - # nn.Sequential([nn.Dense(self.config.hidden_size), ACT2FN[self.config.hidden_act]]) - # for i in range(len(self.config.neck_hidden_sizes)) - # ] self.readout_projects = FlaxDPTReadoutProjectCollectionLayer(self.config, self.dtype) def __call__(self, hidden_states): @@ -395,7 +396,6 @@ def __call__(self, hidden_states): readout = jnp.expand_dims(cls_token, 
axis=1) readout = jnp.repeat(readout, size * size, axis=1) # concatenate the readout token to the hidden states and project - # hidden_state = self.readout_projects[i](: hidden_state = self.readout_projects(jnp.concatenate((hidden_state, readout), axis=-1), i) # reshape back to (B, C, H, W) hidden_state = jnp.reshape(hidden_state, feature_shape) @@ -417,7 +417,6 @@ class FlaxDPTFeatureFusionStage(nn.Module): def setup(self): super().__init__() - # self.layers = [FlaxDPTFeatureFusionLayer(self.config) for i in range(len(self.config.neck_hidden_sizes))] self.layers = FlaxDPTFeatureFusionLayerCollection(self.config, self.dtype) def __call__(self, hidden_states): @@ -429,14 +428,9 @@ def __call__(self, hidden_states): fused_hidden_state = self.layers(hidden_states[0], residual=None, i=0) fused_hidden_states.append(fused_hidden_state) # looping from the last layer to the second - # for hidden_state, layer in zip(hidden_states[1:], self.layers.layers[1:]): - # fused_hidden_state = layer(fused_hidden_state, hidden_state) - # fused_hidden_states.append(fused_hidden_state) - for i, hidden_state in enumerate(hidden_states[1:]): fused_hidden_state = self.layers(fused_hidden_state, residual=hidden_state, i=i + 1) fused_hidden_states.append(fused_hidden_state) - return fused_hidden_states @@ -452,7 +446,7 @@ def setup(self): self.config.fusion_hidden_size, kernel_size=(3, 3), strides=(1, 1), - padding=1, + padding="SAME", use_bias=not self.use_batch_norm, ) @@ -461,7 +455,7 @@ def setup(self): self.config.fusion_hidden_size, kernel_size=(3, 3), strides=(1, 1), - padding=1, + padding="SAME", use_bias=not self.use_batch_norm, ) @@ -848,7 +842,6 @@ def __call__(self, hidden_states): # postprocess hidden states features = self.reassemble_stage(hidden_states) - # features = [self.convs[i](feature) for i, feature in enumerate(features)] features = self.convs(features) # fusion blocks @@ -877,16 +870,16 @@ class FlaxDPTDepthEstimationHeadCollectionLayer(nn.Module): def setup(self): self.conv1 = nn.Conv( - self.config.fusion_hidden_size // 2, kernel_size=(3, 3), strides=(1, 1), padding=1, name="0" + self.config.fusion_hidden_size // 2, kernel_size=(3, 3), strides=(1, 1), padding="SAME", name="0" ) self.upsample = FlaxDPTUpsample(scale=2, method="bilinear") - self.conv2 = nn.Conv(32, kernel_size=(3, 3), strides=(1, 1), padding=1, name="2") + self.conv2 = nn.Conv(32, kernel_size=(3, 3), strides=(1, 1), padding="SAME", name="2") self.act = ACT2FN["relu"] - self.conv3 = nn.Conv(1, kernel_size=(1, 1), strides=(1, 1), padding=0, name="4") + self.conv3 = nn.Conv(1, kernel_size=(1, 1), strides=(1, 1), padding="VALID", name="4") def __call__(self, hidden_state): x = self.conv1(hidden_state) @@ -970,7 +963,7 @@ def __call__( >>> prediction = torch.nn.functional.interpolate( ... predicted_depth.unsqueeze(1), ... size=image.size[::-1], - ... mode="bicubic", + ... mode="bilinear", ... align_corners=False, ... 
) @@ -997,11 +990,11 @@ def __call__( # note that the hidden_states also include the initial embeddings if return_dict: hidden_states = [ - feature for idx, feature in enumerate(hidden_states) if idx in self.config.backbone_out_indices + feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices ] else: hidden_states = [ - feature for idx, feature in enumerate(hidden_states[1]) if idx in self.config.backbone_out_indices + feature for idx, feature in enumerate(hidden_states[1][1:]) if idx in self.config.backbone_out_indices ] hidden_states = self.neck(hidden_states) @@ -1178,11 +1171,11 @@ def __call__( # note that the hidden_states also include the initial embeddings if return_dict: hidden_states = [ - feature for idx, feature in enumerate(hidden_states) if idx in self.config.backbone_out_indices + feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices ] else: hidden_states = [ - feature for idx, feature in enumerate(hidden_states[1]) if idx in self.config.backbone_out_indices + feature for idx, feature in enumerate(hidden_states[1][1:]) if idx in self.config.backbone_out_indices ] hidden_states = self.neck(hidden_states) diff --git a/tests/models/dpt/test_modeling_flax_dpt.py b/tests/models/dpt/test_modeling_flax_dpt.py index 1c6a895a6819..1ace983dbe78 100644 --- a/tests/models/dpt/test_modeling_flax_dpt.py +++ b/tests/models/dpt/test_modeling_flax_dpt.py @@ -20,7 +20,7 @@ import numpy as np from transformers import DPTConfig, is_flax_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow +from transformers.testing_utils import require_flax, slow from ...test_configuration_common import ConfigTester from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor @@ -141,141 +141,6 @@ def setUp(self) -> None: def test_config(self): self.config_tester.run_common_tests() - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - import tempfile - - import torch - - import jax.numpy as jnp - import transformers - from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax - from transformers.testing_utils import torch_device - - # It might be better to put this inside the for loop below (because we modify the config there). - # But logically, it is fine. - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - prepared_inputs_dict.pop("labels") - pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - fx_model = model_class(config, dtype=jnp.float32) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**prepared_inputs_dict) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict) - - fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) - - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - import tempfile - - import torch - - import jax.numpy as jnp - import transformers - from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model - from transformers.testing_utils import torch_device - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - prepared_inputs_dict.pop("labels") - pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - fx_model = model_class(config, dtype=jnp.float32) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**prepared_inputs_dict) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) - - # send pytorch model to the correct device - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys) - # We neeed to override this test because ViT's forward signature is different than text models. def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -320,4 +185,10 @@ def test_model_from_pretrained(self): outputs = model(np.ones((1, 3, 384, 384))) self.assertIsNotNone(outputs) - # TODO: add tests for segmentation and depth estimation (logits) + +# TODO: add tests for segmentation and depth estimation (logits) +# @require_vision +# @slow +# def test_model_from_pretrained_example(self): +# model = FlaxDPTForDepthEstimation.from_pretrained("Intel/dpt-large", from_pt=True) +# image = prepare_img() From 63a61444e1581c8ba4ddf55d5428f4117fbb77c6 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sat, 25 Jun 2022 22:42:17 +0200 Subject: [PATCH 09/22] style --- .../models/dpt/gradient_convolution.py | 587 +++++++++--------- 1 file changed, 289 insertions(+), 298 deletions(-) diff --git a/src/transformers/models/dpt/gradient_convolution.py b/src/transformers/models/dpt/gradient_convolution.py index d67fb9b6537c..9209a65b7de4 100644 --- a/src/transformers/models/dpt/gradient_convolution.py +++ b/src/transformers/models/dpt/gradient_convolution.py @@ -1,16 +1,14 @@ -from typing import (Any, Callable, Optional, Sequence, Tuple, - Union) +from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union + +import numpy as np + import flax.linen as nn -from jax import lax -from typing import (Any, Callable, NamedTuple, Optional, Sequence, Tuple, - Union) -from jax.lax import conv_general_dilated -from flax.linen.initializers import lecun_normal -from flax.linen.initializers import zeros +import jax.numpy as jnp +from flax.linen.initializers import lecun_normal, zeros from flax.linen.module import compact from jax import lax -import jax.numpy as jnp -import numpy as np +from jax.lax import conv_general_dilated + default_kernel_init = lecun_normal() @@ -26,144 +24,140 @@ class ConvDimensionNumbers(NamedTuple): - """ - Describes batch, spatial, and feature dimensions of a convolution. 
- - Args: - lhs_spec: a tuple of nonnegative integer dimension numbers containing - `(batch dimension, feature dimension, spatial dimensions...)`. - rhs_spec: a tuple of nonnegative integer dimension numbers containing - `(out feature dimension, in feature dimension, spatial dimensions...)`. - out_spec: a tuple of nonnegative integer dimension numbers containing - `(batch dimension, feature dimension, spatial dimensions...)`. - """ - lhs_spec: Sequence[int] - rhs_spec: Sequence[int] - out_spec: Sequence[int] - -ConvGeneralDilatedDimensionNumbers = Union[ - None, ConvDimensionNumbers, Tuple[str, str, str]] + """ + Describes batch, spatial, and feature dimensions of a convolution. + + Args: + lhs_spec: a tuple of nonnegative integer dimension numbers containing + `(batch dimension, feature dimension, spatial dimensions...)`. + rhs_spec: a tuple of nonnegative integer dimension numbers containing + `(out feature dimension, in feature dimension, spatial dimensions...)`. + out_spec: a tuple of nonnegative integer dimension numbers containing + `(batch dimension, feature dimension, spatial dimensions...)`. + """ + + lhs_spec: Sequence[int] + rhs_spec: Sequence[int] + out_spec: Sequence[int] + + +ConvGeneralDilatedDimensionNumbers = Union[None, ConvDimensionNumbers, Tuple[str, str, str]] + def _flip_axes(x, axes): - """Flip ndarray 'x' along each axis specified in axes tuple.""" - for axis in axes: - x = np.flip(x, axis) - return x + """Flip ndarray 'x' along each axis specified in axes tuple.""" + for axis in axes: + x = np.flip(x, axis) + return x + def conv_general_permutations(dimension_numbers): - """Utility for convolution dimension permutations relative to Conv HLO.""" - lhs_spec, rhs_spec, out_spec = dimension_numbers - lhs_char, rhs_char, out_char = charpairs = ("N", "C"), ("O", "I"), ("N", "C") - for i, (a, b) in enumerate(charpairs): - if not dimension_numbers[i].count(a) == dimension_numbers[i].count(b) == 1: - msg = ("convolution dimension_numbers[{}] must contain the characters " - "'{}' and '{}' exactly once, got {}.") - raise TypeError(msg.format(i, a, b, dimension_numbers[i])) - if len(dimension_numbers[i]) != len(set(dimension_numbers[i])): - msg = ("convolution dimension_numbers[{}] cannot have duplicate " - "characters, got {}.") - raise TypeError(msg.format(i, dimension_numbers[i])) - if not (set(lhs_spec) - set(lhs_char) == set(rhs_spec) - set(rhs_char) == - set(out_spec) - set(out_char)): - msg = ("convolution dimension_numbers elements must each have the same " - "set of spatial characters, got {}.") - raise TypeError(msg.format(dimension_numbers)) - - def getperm(spec, charpair): - spatial = (i for i, c in enumerate(spec) if c not in charpair) - if spec is not rhs_spec: - spatial = sorted(spatial, key=lambda i: rhs_spec.index(spec[i])) - return (spec.index(charpair[0]), spec.index(charpair[1])) + tuple(spatial) - - lhs_perm, rhs_perm, out_perm = map(getperm, dimension_numbers, charpairs) - return lhs_perm, rhs_perm, out_perm - -def conv_dimension_numbers_(lhs_shape, rhs_shape, dimension_numbers - ) -> ConvDimensionNumbers: - """Converts convolution `dimension_numbers` to a `ConvDimensionNumbers`. - - Args: - lhs_shape: tuple of nonnegative integers, shape of the convolution input. - rhs_shape: tuple of nonnegative integers, shape of the convolution kernel. - dimension_numbers: None or a tuple/list of strings or a ConvDimensionNumbers - object following the convolution dimension number specification format in xla_client.py. 
- - Returns: - A `ConvDimensionNumbers` object that represents `dimension_numbers` in the canonical form used by lax functions. - """ - if isinstance(dimension_numbers, ConvDimensionNumbers): - return dimension_numbers - if len(lhs_shape) != len(rhs_shape): - msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}." - raise TypeError(msg.format(len(lhs_shape), len(rhs_shape))) - - if dimension_numbers is None: - iota = tuple(range(len(lhs_shape))) - return ConvDimensionNumbers(iota, iota, iota) - elif isinstance(dimension_numbers, (list, tuple)): - if len(dimension_numbers) != 3: - msg = "convolution dimension_numbers list/tuple must be length 3, got {}." - raise TypeError(msg.format(len(dimension_numbers))) - if not all(isinstance(elt, str) for elt in dimension_numbers): - msg = "convolution dimension_numbers elements must be strings, got {}." - raise TypeError(msg.format(tuple(map(type, dimension_numbers)))) - msg = ("convolution dimension_numbers[{}] must have len equal to the ndim " - "of lhs and rhs, got {} for lhs and rhs shapes {} and {}.") - for i, elt in enumerate(dimension_numbers): - if len(elt) != len(lhs_shape): - raise TypeError(msg.format(i, len(elt), lhs_shape, rhs_shape)) - - lhs_spec, rhs_spec, out_spec = conv_general_permutations(dimension_numbers) - return ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) - else: - msg = "convolution dimension_numbers must be tuple/list or None, got {}." - raise TypeError(msg.format(type(dimension_numbers))) - -def _deconv_output_length(input_length, - filter_size, - padding, - output_padding=None, - stride=0, - dilation=1): - """Determines the output length of a transposed convolution given the input length. - Arguments: - Function modified from Keras. - input_length: Integer. filter_size: Integer. padding: one of `"SAME"`, `"VALID"`, or a 2-integer tuple. - output_padding: Integer, amount of padding along the output dimension. Can - be set to `None` in which case the output length is inferred. - stride: Integer. dilation: Integer. - Returns: - The output length (integer). - """ - if input_length is None: - return None - - # Get the dilated kernel size - filter_size = filter_size + (filter_size - 1) * (dilation - 1) - - # Infer length if output padding is None, else compute the exact length - if output_padding is None: - if padding == 'VALID': - length = input_length * stride + max(filter_size - stride, 0) - elif padding == 'SAME': - length = input_length * stride + """Utility for convolution dimension permutations relative to Conv HLO.""" + lhs_spec, rhs_spec, out_spec = dimension_numbers + lhs_char, rhs_char, out_char = charpairs = ("N", "C"), ("O", "I"), ("N", "C") + for i, (a, b) in enumerate(charpairs): + if not dimension_numbers[i].count(a) == dimension_numbers[i].count(b) == 1: + msg = "convolution dimension_numbers[{}] must contain the characters '{}' and '{}' exactly once, got {}." + raise TypeError(msg.format(i, a, b, dimension_numbers[i])) + if len(dimension_numbers[i]) != len(set(dimension_numbers[i])): + msg = "convolution dimension_numbers[{}] cannot have duplicate characters, got {}." + raise TypeError(msg.format(i, dimension_numbers[i])) + if not (set(lhs_spec) - set(lhs_char) == set(rhs_spec) - set(rhs_char) == set(out_spec) - set(out_char)): + msg = "convolution dimension_numbers elements must each have the same set of spatial characters, got {}." 
+ raise TypeError(msg.format(dimension_numbers)) + + def getperm(spec, charpair): + spatial = (i for i, c in enumerate(spec) if c not in charpair) + if spec is not rhs_spec: + spatial = sorted(spatial, key=lambda i: rhs_spec.index(spec[i])) + return (spec.index(charpair[0]), spec.index(charpair[1])) + tuple(spatial) + + lhs_perm, rhs_perm, out_perm = map(getperm, dimension_numbers, charpairs) + return lhs_perm, rhs_perm, out_perm + + +def conv_dimension_numbers_(lhs_shape, rhs_shape, dimension_numbers) -> ConvDimensionNumbers: + """Converts convolution `dimension_numbers` to a `ConvDimensionNumbers`. + + Args: + lhs_shape: tuple of nonnegative integers, shape of the convolution input. + rhs_shape: tuple of nonnegative integers, shape of the convolution kernel. + dimension_numbers: None or a tuple/list of strings or a ConvDimensionNumbers + object following the convolution dimension number specification format in xla_client.py. + + Returns: + A `ConvDimensionNumbers` object that represents `dimension_numbers` in the canonical form used by lax functions. + """ + if isinstance(dimension_numbers, ConvDimensionNumbers): + return dimension_numbers + if len(lhs_shape) != len(rhs_shape): + msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}." + raise TypeError(msg.format(len(lhs_shape), len(rhs_shape))) + + if dimension_numbers is None: + iota = tuple(range(len(lhs_shape))) + return ConvDimensionNumbers(iota, iota, iota) + elif isinstance(dimension_numbers, (list, tuple)): + if len(dimension_numbers) != 3: + msg = "convolution dimension_numbers list/tuple must be length 3, got {}." + raise TypeError(msg.format(len(dimension_numbers))) + if not all(isinstance(elt, str) for elt in dimension_numbers): + msg = "convolution dimension_numbers elements must be strings, got {}." + raise TypeError(msg.format(tuple(map(type, dimension_numbers)))) + msg = ( + "convolution dimension_numbers[{}] must have len equal to the ndim " + "of lhs and rhs, got {} for lhs and rhs shapes {} and {}." + ) + for i, elt in enumerate(dimension_numbers): + if len(elt) != len(lhs_shape): + raise TypeError(msg.format(i, len(elt), lhs_shape, rhs_shape)) + + lhs_spec, rhs_spec, out_spec = conv_general_permutations(dimension_numbers) + return ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) else: - length = ((input_length - 1) * stride + filter_size - - padding[0] - padding[1]) - - else: - if padding == 'SAME': - pad = filter_size // 2 - total_pad = pad * 2 - elif padding == 'VALID': - total_pad = 0 + msg = "convolution dimension_numbers must be tuple/list or None, got {}." + raise TypeError(msg.format(type(dimension_numbers))) + + +def _deconv_output_length(input_length, filter_size, padding, output_padding=None, stride=0, dilation=1): + """Determines the output length of a transposed convolution given the input length. + Arguments: + Function modified from Keras. + input_length: Integer. filter_size: Integer. padding: one of `"SAME"`, `"VALID"`, or a 2-integer tuple. + output_padding: Integer, amount of padding along the output dimension. Can + be set to `None` in which case the output length is inferred. + stride: Integer. dilation: Integer. + Returns: + The output length (integer). 
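Before the implementation continues below, the three length-inference branches described in this docstring are easiest to see with concrete numbers. The snippet is an editor's illustration, not part of the patch; it spells the arithmetic out by hand rather than importing the private helper.

```python
# Stride 2, kernel size 4, dilation 1, output_padding left as None (inferred).
input_length, stride, kernel = 16, 2, 4

same_length = input_length * stride                              # "SAME"         -> 32
valid_length = input_length * stride + max(kernel - stride, 0)   # "VALID"        -> 34
explicit_length = (input_length - 1) * stride + kernel - 1 - 1   # padding=(1, 1) -> 32

print(same_length, valid_length, explicit_length)  # 32 34 32
```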
+ """ + if input_length is None: + return None + + # Get the dilated kernel size + filter_size = filter_size + (filter_size - 1) * (dilation - 1) + + # Infer length if output padding is None, else compute the exact length + if output_padding is None: + if padding == "VALID": + length = input_length * stride + max(filter_size - stride, 0) + elif padding == "SAME": + length = input_length * stride + else: + length = (input_length - 1) * stride + filter_size - padding[0] - padding[1] + else: - total_pad = padding[0] + padding[1] + if padding == "SAME": + pad = filter_size // 2 + total_pad = pad * 2 + elif padding == "VALID": + total_pad = 0 + else: + total_pad = padding[0] + padding[1] - length = ((input_length - 1) * stride + filter_size - total_pad + - output_padding) + length = (input_length - 1) * stride + filter_size - total_pad + output_padding + + return length - return length def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: if isinstance(padding, str): @@ -171,7 +165,6 @@ def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: if isinstance(padding, int): return [(padding, padding)] * rank - if isinstance(padding, Sequence) and len(padding) == rank: new_pad = [] for p in padding: @@ -184,9 +177,11 @@ def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: if len(new_pad) == rank: return new_pad raise ValueError( - f'Invalid padding format: {padding}, should be str, int,' - f' or a sequence of len {rank} where each element is an' - f' int or pair of ints.') + f"Invalid padding format: {padding}, should be str, int," + f" or a sequence of len {rank} where each element is an" + " int or pair of ints." + ) + def _compute_adjusted_padding( input_size: int, @@ -196,136 +191,145 @@ def _compute_adjusted_padding( padding: Union[str, Tuple[int, int]], dilation: int = 1, ) -> Tuple[int, int]: - """Computes adjusted padding for desired ConvTranspose `output_size`. - Ported from DeepMind Haiku. 
- """ - kernel_size = (kernel_size - 1) * dilation + 1 - if padding == "VALID": - expected_input_size = (output_size - kernel_size + stride) // stride - if input_size != expected_input_size: - raise ValueError(f"The expected input size with the current set of input " - f"parameters is {expected_input_size} which doesn't " - f"match the actual input size {input_size}.") - padding_before = 0 - elif padding == "SAME": - expected_input_size = (output_size + stride - 1) // stride - if input_size != expected_input_size: - raise ValueError(f"The expected input size with the current set of input " - f"parameters is {expected_input_size} which doesn't " - f"match the actual input size {input_size}.") - padding_needed = max(0, - (input_size - 1) * stride + kernel_size - output_size) - padding_before = padding_needed // 2 - else: - padding_before = padding[0] # type: ignore[assignment] - - expanded_input_size = (input_size - 1) * stride + 1 - padded_out_size = output_size + kernel_size - 1 - pad_before = kernel_size - 1 - padding_before - pad_after = padded_out_size - expanded_input_size - pad_before - return (pad_before, pad_after) - -def gradient_based_conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], - padding: Union[str, Sequence[Tuple[int, int]]], - output_padding: Optional[Sequence[int]] = None, - output_shape: Optional[Sequence[int]] = None, - dilation: Optional[Sequence[int]] = None, - dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, - transpose_kernel: bool = True, - precision: PrecisionLike = None) -> Array: - """Convenience wrapper for calculating the N-d transposed convolution. - Args: - Much like *conv_transpose*, this function calculates transposed convolutions via fractionally strided convolution - rather than calculating the gradient (transpose) of a forward convolution. However, the latter is more common among - deep learning frameworks, such as TensorFlow, PyTorch, and Keras. This function provides the same set of APIs to help: - reproduce results in these frameworks. - lhs: a rank *n+2* dimensional input array. rhs: a rank *n+2* dimensional array of kernel weights. strides: sequence - of *n* integers, amounts to strides of the corresponding forward convolution. padding: *"SAME"*, *"VALID"*, or a - sequence of *n* integer 2-tuples that controls - the before-and-after padding for each *n* spatial dimension of the corresponding forward convolution. - output_padding: A sequence of integers specifying the amount of padding along - each spacial dimension of the output tensor, used to disambiguate the output shape of transposed convolutions - when the stride is larger than 1. (see a detailed description at - 1https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html) The amount of output padding along a - given dimension must be lower than the stride along that same dimension. If set to *None* (default), the output - shape is inferred. If both *output_padding* and *output_shape* are specified, they have to be mutually - compatible. - output_shape: Output shape of the spatial dimensions of a transpose - convolution. Can be *None* or an iterable of *n* integers. If a *None* value is given (default), the shape is - automatically calculated. Similar to *output_padding*, *output_shape* is also for disambiguating the output shape - when stride > 1 (see also https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose) If both - *output_padding* and *output_shape* are specified, they have to be mutually compatible. 
- dilation: *None*, or a sequence of *n* integers, giving the - dilation factor to apply in each spatial dimension of *rhs*. Dilated convolution is also known as atrous - convolution. - dimension_numbers: tuple of dimension descriptors as in - lax.conv_general_dilated. Defaults to tensorflow convention. - transpose_kernel: if *True* flips spatial axes and swaps the input/output - channel axes of the kernel. This makes the output of this function identical to the gradient-derived functions - like keras.layers.Conv2DTranspose and torch.nn.ConvTranspose2d applied to the same kernel. Although for typical - use in neural nets this is unnecessary and makes input/output channel specification confusing, you need to set - this to *True* in order to match the behavior in many deep learning frameworks, such as TensorFlow, Keras, and - PyTorch. - precision: Optional. Either `None`, which means the default precision for - the backend, a `lax.Precision` enum value (`Precision.DEFAULT`, `Precision.HIGH` or `Precision.HIGHEST`) or a - tuple of two `lax.Precision` enums indicating precision of ``lhs``` and `rhs`. - Returns: - Transposed N-d convolution. - """ - assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2 - ndims = len(lhs.shape) - one = (1,) * (ndims - 2) - # Set dimensional layout defaults if not specified. - if dimension_numbers is None: - if ndims == 2: - dimension_numbers = ('NC', 'IO', 'NC') - elif ndims == 3: - dimension_numbers = ('NHC', 'HIO', 'NHC') - elif ndims == 4: - dimension_numbers = ('NHWC', 'HWIO', 'NHWC') - elif ndims == 5: - dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC') - else: - raise ValueError('No 4+ dimensional dimension_number defaults.') - dn = conv_dimension_numbers_(lhs.shape, rhs.shape, dimension_numbers) - # dn = dimension_numbers - k_shape = np.take(rhs.shape, dn.rhs_spec) - k_sdims = k_shape[2:] # type: ignore[index] - i_shape = np.take(lhs.shape, dn.lhs_spec) - i_sdims = i_shape[2:] # type: ignore[index] - - # Calculate correct output shape given padding and strides. - if dilation is None: - dilation = (1,) * (rhs.ndim - 2) - - if output_padding is None: - output_padding = [None] * (rhs.ndim - 2) # type: ignore[list-item] - - if isinstance(padding, str): - if padding in {'SAME', 'VALID'}: - padding = [padding] * (rhs.ndim - 2) # type: ignore[list-item] + """Computes adjusted padding for desired ConvTranspose `output_size`. + Ported from DeepMind Haiku. + """ + kernel_size = (kernel_size - 1) * dilation + 1 + if padding == "VALID": + expected_input_size = (output_size - kernel_size + stride) // stride + if input_size != expected_input_size: + raise ValueError( + "The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}." + ) + padding_before = 0 + elif padding == "SAME": + expected_input_size = (output_size + stride - 1) // stride + if input_size != expected_input_size: + raise ValueError( + "The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}." + ) + padding_needed = max(0, (input_size - 1) * stride + kernel_size - output_size) + padding_before = padding_needed // 2 else: - raise ValueError(f"`padding` must be 'VALID' or 'SAME'. 
Passed: {padding}.") + padding_before = padding[0] # type: ignore[assignment] + + expanded_input_size = (input_size - 1) * stride + 1 + padded_out_size = output_size + kernel_size - 1 + pad_before = kernel_size - 1 - padding_before + pad_after = padded_out_size - expanded_input_size - pad_before + return (pad_before, pad_after) + + +def gradient_based_conv_transpose( + lhs: Array, + rhs: Array, + strides: Sequence[int], + padding: Union[str, Sequence[Tuple[int, int]]], + output_padding: Optional[Sequence[int]] = None, + output_shape: Optional[Sequence[int]] = None, + dilation: Optional[Sequence[int]] = None, + dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, + transpose_kernel: bool = True, + precision: PrecisionLike = None, +) -> Array: + """Convenience wrapper for calculating the N-d transposed convolution. + Args: + Much like *conv_transpose*, this function calculates transposed convolutions via fractionally strided convolution + rather than calculating the gradient (transpose) of a forward convolution. However, the latter is more common among: + deep learning frameworks, such as TensorFlow, PyTorch, and Keras. This function provides the same set of APIs to help: + reproduce results in these frameworks. + lhs: a rank *n+2* dimensional input array. rhs: a rank *n+2* dimensional array of kernel weights. strides: + sequence of *n* integers, amounts to strides of the corresponding forward convolution. padding: *"SAME"*, + *"VALID"*, or a sequence of *n* integer 2-tuples that controls + the before-and-after padding for each *n* spatial dimension of the corresponding forward convolution. + output_padding: A sequence of integers specifying the amount of padding along + each spacial dimension of the output tensor, used to disambiguate the output shape of transposed convolutions + when the stride is larger than 1. (see a detailed description at + 1https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html) The amount of output padding along a + given dimension must be lower than the stride along that same dimension. If set to *None* (default), the output + shape is inferred. If both *output_padding* and *output_shape* are specified, they have to be mutually + compatible. + output_shape: Output shape of the spatial dimensions of a transpose + convolution. Can be *None* or an iterable of *n* integers. If a *None* value is given (default), the shape is + automatically calculated. Similar to *output_padding*, *output_shape* is also for disambiguating the output + shape when stride > 1 (see also https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose) If both + *output_padding* and *output_shape* are specified, they have to be mutually compatible. + dilation: *None*, or a sequence of *n* integers, giving the + dilation factor to apply in each spatial dimension of *rhs*. Dilated convolution is also known as atrous + convolution. + dimension_numbers: tuple of dimension descriptors as in + lax.conv_general_dilated. Defaults to tensorflow convention. + transpose_kernel: if *True* flips spatial axes and swaps the input/output + channel axes of the kernel. This makes the output of this function identical to the gradient-derived functions + like keras.layers.Conv2DTranspose and torch.nn.ConvTranspose2d applied to the same kernel. 
Although for typical + use in neural nets this is unnecessary and makes input/output channel specification confusing, you need to set + this to *True* in order to match the behavior in many deep learning frameworks, such as TensorFlow, Keras, and + PyTorch. + precision: Optional. Either `None`, which means the default precision for + the backend, a `lax.Precision` enum value (`Precision.DEFAULT`, `Precision.HIGH` or `Precision.HIGHEST`) or a + tuple of two `lax.Precision` enums indicating precision of ``lhs``` and `rhs`. + Returns: + Transposed N-d convolution. + """ + assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2 + ndims = len(lhs.shape) + one = (1,) * (ndims - 2) + # Set dimensional layout defaults if not specified. + if dimension_numbers is None: + if ndims == 2: + dimension_numbers = ("NC", "IO", "NC") + elif ndims == 3: + dimension_numbers = ("NHC", "HIO", "NHC") + elif ndims == 4: + dimension_numbers = ("NHWC", "HWIO", "NHWC") + elif ndims == 5: + dimension_numbers = ("NHWDC", "HWDIO", "NHWDC") + else: + raise ValueError("No 4+ dimensional dimension_number defaults.") + dn = conv_dimension_numbers_(lhs.shape, rhs.shape, dimension_numbers) + # dn = dimension_numbers + k_shape = np.take(rhs.shape, dn.rhs_spec) + k_sdims = k_shape[2:] # type: ignore[index] + i_shape = np.take(lhs.shape, dn.lhs_spec) + i_sdims = i_shape[2:] # type: ignore[index] - inferred_output_shape = tuple(map(_deconv_output_length, i_sdims, k_sdims, - padding, output_padding, strides, dilation)) - if output_shape is None: - output_shape = inferred_output_shape # type: ignore[assignment] - else: - if not output_shape == inferred_output_shape: - raise ValueError(f"`output_padding` and `output_shape` are not compatible." - f"Inferred output shape from `output_padding`: {inferred_output_shape}, " - f"but got `output_shape` {output_shape}") + # Calculate correct output shape given padding and strides. + if dilation is None: + dilation = (1,) * (rhs.ndim - 2) - pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, - k_sdims, strides, padding, dilation)) + if output_padding is None: + output_padding = [None] * (rhs.ndim - 2) # type: ignore[list-item] - if transpose_kernel: - # flip spatial dims and swap input / output channel axes - rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) - rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) - return conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dimension_numbers, - precision=precision) + if isinstance(padding, str): + if padding in {"SAME", "VALID"}: + padding = [padding] * (rhs.ndim - 2) # type: ignore[list-item] + else: + raise ValueError(f"`padding` must be 'VALID' or 'SAME'. Passed: {padding}.") + + inferred_output_shape = tuple( + map(_deconv_output_length, i_sdims, k_sdims, padding, output_padding, strides, dilation) + ) + if output_shape is None: + output_shape = inferred_output_shape # type: ignore[assignment] + else: + if not output_shape == inferred_output_shape: + raise ValueError( + "`output_padding` and `output_shape` are not compatible." 
+ f"Inferred output shape from `output_padding`: {inferred_output_shape}, " + f"but got `output_shape` {output_shape}" + ) + + pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, k_sdims, strides, padding, dilation)) + + if transpose_kernel: + # flip spatial dims and swap input / output channel axes + rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) + rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) + return conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dimension_numbers, precision=precision) class ConvTransposeGradient(nn.Module): @@ -355,10 +359,11 @@ class ConvTransposeGradient(nn.Module): kernel_init: initializer for the convolutional kernel. bias_init: initializer for the bias. """ + features: int kernel_size: Union[int, Tuple[int, ...]] strides: Optional[Tuple[int, ...]] = None - padding: PaddingLike = 'SAME' + padding: PaddingLike = "SAME" kernel_dilation: Optional[Sequence[int]] = None use_bias: bool = True dtype: Dtype = jnp.float32 @@ -400,23 +405,18 @@ def __call__(self, inputs: Array) -> Array: in_features = inputs.shape[-1] kernel_shape = kernel_size + (in_features, self.features) - kernel = self.param('kernel', self.kernel_init, kernel_shape, - self.param_dtype) + kernel = self.param("kernel", self.kernel_init, kernel_shape, self.param_dtype) kernel = jnp.asarray(kernel, self.dtype) padding_lax = canonicalize_padding(self.padding, len(kernel_size)) - if padding_lax == 'CIRCULAR': - padding_lax = 'VALID' + if padding_lax == "CIRCULAR": + padding_lax = "VALID" y = gradient_based_conv_transpose( - inputs, - kernel, - strides, - padding_lax, - dilation=self.kernel_dilation, - precision=self.precision) - - if self.padding == 'CIRCULAR': + inputs, kernel, strides, padding_lax, dilation=self.kernel_dilation, precision=self.precision + ) + + if self.padding == "CIRCULAR": # For circular padding, we need to identify the size of the final output # ("period") along each spatial dimension, pad each dimension to an # integer number of periods, and wrap the array periodically around each @@ -426,35 +426,26 @@ def __call__(self, inputs: Array) -> Array: # Compute period along each spatial dimension - it's input size scaled # by the stride. - scaled_x_dims = [ - x_dim * stride for x_dim, stride in zip(inputs.shape[1:-1], strides) - ] + scaled_x_dims = [x_dim * stride for x_dim, stride in zip(inputs.shape[1:-1], strides)] # Compute difference between the current size of y and the final output # size, and complement this difference to 2 * period - that gives how # much we need to pad. - size_diffs = [ - -(y_dim - x_dim) % (2 * x_dim) - for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims) - ] + size_diffs = [-(y_dim - x_dim) % (2 * x_dim) for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims)] # Divide the padding equaly between left and right. The choice to put # "+1" on the left (and not on the right) represents a convention for # aligning even-sized kernels. - total_pad = [ - ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs - ] + total_pad = [((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs] y = np.pad(y, [(0, 0)] + total_pad + [(0, 0)]) # Wrap the result periodically around each spatial dimension, # one by one. 
for i in range(1, y.ndim - 1): - y = y.reshape(y.shape[:i] + (-1, scaled_x_dims[i - 1]) + - y.shape[i + 1:]) + y = y.reshape(y.shape[:i] + (-1, scaled_x_dims[i - 1]) + y.shape[i + 1 :]) y = y.sum(axis=i) if is_single_input: y = jnp.squeeze(y, axis=0) if self.use_bias: - bias = self.param('bias', self.bias_init, (self.features,), - self.param_dtype) + bias = self.param("bias", self.bias_init, (self.features,), self.param_dtype) bias = jnp.asarray(bias, self.dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) - return y \ No newline at end of file + return y From 3d74f94c895c05ab2e5a840db95c3683b76bee16 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sat, 25 Jun 2022 22:45:22 +0200 Subject: [PATCH 10/22] add FlaxViTPatchEmbeddings for consistency --- src/transformers/models/dpt/modeling_flax_dpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index 78ac6aa2ab00..b1ff5c73ee7a 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -80,8 +80,8 @@ """ -# Copied from transformers.models.vit.modeling_flax_vit.FlaxPatchEmbeddings with ViT->DPT -class FlaxPatchEmbeddings(nn.Module): +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTPatchEmbeddings with ViT->DPT +class FlaxViTPatchEmbeddings(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation From ed7a4bc6f11307bd73f6d6dd139b364e48fe9e75 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sat, 25 Jun 2022 23:09:06 +0200 Subject: [PATCH 11/22] update tests --- tests/models/dpt/test_modeling_flax_dpt.py | 137 ++++++++++++++++++++- 1 file changed, 136 insertions(+), 1 deletion(-) diff --git a/tests/models/dpt/test_modeling_flax_dpt.py b/tests/models/dpt/test_modeling_flax_dpt.py index 1ace983dbe78..05cbe986e260 100644 --- a/tests/models/dpt/test_modeling_flax_dpt.py +++ b/tests/models/dpt/test_modeling_flax_dpt.py @@ -20,7 +20,7 @@ import numpy as np from transformers import DPTConfig, is_flax_available -from transformers.testing_utils import require_flax, slow +from transformers.testing_utils import require_flax, slow, is_pt_flax_cross_test from ...test_configuration_common import ConfigTester from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor @@ -185,6 +185,141 @@ def test_model_from_pretrained(self): outputs = model(np.ones((1, 3, 384, 384))) self.assertIsNotNone(outputs) + @is_pt_flax_cross_test + def test_equivalence_pt_to_flax(self): + import tempfile + + import torch + + import jax.numpy as jnp + import transformers + from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax + from transformers.testing_utils import torch_device + + # It might be better to put this inside the for loop below (because we modify the config there). + # But logically, it is fine. 
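The restored test (which continues just below) follows the usual cross-framework recipe: build the PyTorch and Flax models from one config, port the PyTorch weights, run both on the same pixel values, and compare the outputs. A condensed sketch of that recipe, assuming the `FlaxDPTModel` added in this PR is importable and mirrors the PyTorch output structure; the tolerance is illustrative, the real test relies on `check_pt_flax_outputs`:

```python
import numpy as np
import torch
import jax.numpy as jnp

from transformers import DPTConfig, DPTModel, FlaxDPTModel
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

config = DPTConfig()
pt_model = DPTModel(config).eval()
fx_model = FlaxDPTModel(config, dtype=jnp.float32)

# Port the (randomly initialised) PyTorch weights into the Flax model.
fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

pixel_values = np.random.randn(1, 3, config.image_size, config.image_size).astype("float32")
with torch.no_grad():
    pt_hidden = pt_model(pixel_values=torch.from_numpy(pixel_values)).last_hidden_state
fx_hidden = fx_model(pixel_values=pixel_values).last_hidden_state  # assumes matching output class

np.testing.assert_allclose(np.asarray(fx_hidden), pt_hidden.numpy(), atol=5e-3)
```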
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + prepared_inputs_dict.pop("labels") + pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class + pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + pt_model = pt_model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. + pt_model.config.use_cache = False + fx_model = model_class(config, dtype=jnp.float32) + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) + + + @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + import tempfile + + import torch + + import jax.numpy as jnp + import transformers + from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model + from transformers.testing_utils import torch_device + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + prepared_inputs_dict.pop("labels") + pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class + pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + pt_model = pt_model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
+ pt_model.config.use_cache = False + fx_model = model_class(config, dtype=jnp.float32) + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) + + # send pytorch model to the correct device + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) # TODO: add tests for segmentation and depth estimation (logits) # @require_vision From e2a61c96cc0ac4548b9d4d7ee520f749ec73c720 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sat, 25 Jun 2022 23:13:13 +0200 Subject: [PATCH 12/22] consistency --- src/transformers/models/dpt/modeling_flax_dpt.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index b1ff5c73ee7a..ad577ebbbf5a 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -81,7 +81,7 @@ # Copied from transformers.models.vit.modeling_flax_vit.FlaxViTPatchEmbeddings with ViT->DPT -class FlaxViTPatchEmbeddings(nn.Module): +class FlaxDPTPatchEmbeddings(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -91,6 +91,7 @@ def setup(self): patch_size = self.config.patch_size num_patches = (image_size // patch_size) * (image_size // patch_size) self.num_patches = num_patches + self.num_channels = self.config.num_channels self.projection = nn.Conv( self.config.hidden_size, kernel_size=(patch_size, patch_size), @@ -101,9 +102,14 @@ def setup(self): ) def __call__(self, pixel_values): - x = self.projection(pixel_values) - batch_size, _, _, channels = x.shape - return jnp.reshape(x, (batch_size, -1, channels)) + num_channels = pixel_values.shape[-1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
+ ) + embeddings = self.projection(pixel_values) + batch_size, _, _, channels = embeddings.shape + return jnp.reshape(embeddings, (batch_size, -1, channels)) # Copied from transformers.models.vit.modeling_flax_vit.FlaxViTEmbeddings with ViT->DPT @@ -115,7 +121,7 @@ class FlaxDPTEmbeddings(nn.Module): def setup(self): self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size)) - self.patch_embeddings = FlaxPatchEmbeddings(self.config, dtype=self.dtype) + self.patch_embeddings = FlaxDPTPatchEmbeddings(self.config, dtype=self.dtype) num_patches = self.patch_embeddings.num_patches self.position_embeddings = self.param( "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size) From e11b0430b19d2aaed28d9afa6060f0ea2f25e21a Mon Sep 17 00:00:00 2001 From: Arthur Date: Sat, 25 Jun 2022 23:13:38 +0200 Subject: [PATCH 13/22] style --- tests/models/dpt/test_modeling_flax_dpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/dpt/test_modeling_flax_dpt.py b/tests/models/dpt/test_modeling_flax_dpt.py index 05cbe986e260..b63a9d895220 100644 --- a/tests/models/dpt/test_modeling_flax_dpt.py +++ b/tests/models/dpt/test_modeling_flax_dpt.py @@ -20,7 +20,7 @@ import numpy as np from transformers import DPTConfig, is_flax_available -from transformers.testing_utils import require_flax, slow, is_pt_flax_cross_test +from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow from ...test_configuration_common import ConfigTester from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor @@ -250,7 +250,6 @@ def test_equivalence_pt_to_flax(self): self.assertEqual(fx_keys, pt_keys) self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) - @is_pt_flax_cross_test def test_equivalence_flax_to_pt(self): import tempfile @@ -321,6 +320,7 @@ def test_equivalence_flax_to_pt(self): self.assertEqual(fx_keys, pt_keys) self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) + # TODO: add tests for segmentation and depth estimation (logits) # @require_vision # @slow From 05dcc855da2d1dd45e6438a596f21ba47ec0c3cc Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Sun, 26 Jun 2022 00:23:29 +0200 Subject: [PATCH 14/22] all tests should pas - added new attribute to config without breaking backward compatibility - modified a bit the tests --- .../models/dpt/configuration_dpt.py | 2 + src/transformers/models/dpt/modeling_dpt.py | 8 +- tests/models/dpt/test_modeling_dpt.py | 169 +++++++++++++++++- tests/models/dpt/test_modeling_flax_dpt.py | 1 + 4 files changed, 174 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/dpt/configuration_dpt.py b/src/transformers/models/dpt/configuration_dpt.py index a255b0596b4d..63bd764c5b2f 100644 --- a/src/transformers/models/dpt/configuration_dpt.py +++ b/src/transformers/models/dpt/configuration_dpt.py @@ -137,6 +137,7 @@ def __init__( auxiliary_loss_weight=0.4, semantic_loss_ignore_index=255, semantic_classifier_dropout=0.1, + align_corners=True, **kwargs ): super().__init__(**kwargs) @@ -168,3 +169,4 @@ def __init__( self.auxiliary_loss_weight = auxiliary_loss_weight self.semantic_loss_ignore_index = semantic_loss_ignore_index self.semantic_classifier_dropout = semantic_classifier_dropout + self.align_corners = align_corners diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 7dfa244805ff..927818e3a84a 100755 --- 
a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -582,10 +582,10 @@ class DPTFeatureFusionLayer(nn.Module): The align_corner setting for bilinear upsample. """ - def __init__(self, config, align_corners=True): + def __init__(self, config): super().__init__() - self.align_corners = align_corners + self.align_corners = config.align_corners self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True) @@ -832,7 +832,7 @@ def __init__(self, config): features = config.fusion_hidden_size self.head = nn.Sequential( nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), - nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=self.config.align_corners), nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), ACT2FN["relu"], nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), @@ -979,7 +979,7 @@ def __init__(self, config): ACT2FN["relu"], nn.Dropout(config.semantic_classifier_dropout), nn.Conv2d(features, config.num_labels, kernel_size=1), - nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=self.config.align_corners), ) def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor: diff --git a/tests/models/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py index 3266ea78a71a..5c53ee9e2a43 100644 --- a/tests/models/dpt/test_modeling_dpt.py +++ b/tests/models/dpt/test_modeling_dpt.py @@ -19,9 +19,9 @@ import unittest from transformers import DPTConfig -from transformers.file_utils import is_torch_available, is_vision_available +from transformers.file_utils import is_flax_available, is_torch_available, is_vision_available from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -40,6 +40,18 @@ from transformers import DPTFeatureExtractor +if is_torch_available() and is_flax_available(): + import tempfile + + import numpy as np + + import jax.numpy as jnp + import transformers + from transformers.modeling_flax_pytorch_utils import ( + convert_pytorch_state_dict_to_flax, + load_flax_weights_in_pytorch_model, + ) + class DPTModelTester: def __init__( @@ -239,6 +251,159 @@ def test_training_gradient_checkpointing(self): loss = model(**inputs).loss loss.backward() + @is_pt_flax_cross_test + def test_equivalence_pt_to_flax(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.align_corners = False + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + # no flax model exists for this class + return + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load PyTorch class + pt_model = model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
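Threading `config.align_corners` through the upsampling layers, and pinning it to `False` in the cross-framework test that continues below, matters because the flag changes the sampling grid itself rather than any metadata; presumably the Flax port only reproduces the `align_corners=False` behaviour. A quick PyTorch-only illustration (editor's sketch, not part of the patch):

```python
import torch
import torch.nn.functional as F

x = torch.arange(16.0).reshape(1, 1, 4, 4)
up_true = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
up_false = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
print((up_true - up_false).abs().max())  # > 0: the two settings sample different grids
```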
+ pt_model.config.use_cache = False + + # load Flax class + fx_model = fx_model_class(config, dtype=jnp.float32) + + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + # send pytorch inputs to the correct device + pt_inputs = { + k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() + } + + # convert inputs to Flax + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) + + @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.align_corners = False + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + # no flax model exists for this class + return + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load PyTorch class + pt_model = model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
+ pt_model.config.use_cache = False + + # load Flax class + fx_model = fx_model_class(config, dtype=jnp.float32) + + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + # send pytorch inputs to the correct device + pt_inputs = { + k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() + } + + # convert inputs to Flax + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) + + # send pytorch model to the correct device + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) + @slow def test_model_from_pretrained(self): for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/dpt/test_modeling_flax_dpt.py b/tests/models/dpt/test_modeling_flax_dpt.py index b63a9d895220..7ad3c453fc70 100644 --- a/tests/models/dpt/test_modeling_flax_dpt.py +++ b/tests/models/dpt/test_modeling_flax_dpt.py @@ -102,6 +102,7 @@ def prepare_config_and_inputs(self): attention_probs_dropout_prob=self.attention_probs_dropout_prob, is_decoder=False, initializer_range=self.initializer_range, + align_corners=False, ) return config, pixel_values, labels From 7c26f69f264305338e1fc285ecd8b4c8e3aeaf38 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 10 Aug 2022 15:26:27 +0200 Subject: [PATCH 15/22] fixing few comments --- src/transformers/models/dpt/gradient_convolution.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/dpt/gradient_convolution.py b/src/transformers/models/dpt/gradient_convolution.py index 9209a65b7de4..bd0f3b562421 100644 --- a/src/transformers/models/dpt/gradient_convolution.py +++ b/src/transformers/models/dpt/gradient_convolution.py @@ -14,10 +14,8 @@ PRNGKey = Any Shape = Tuple[int, ...] -Dtype = Any # this could be a real type? +Dtype = jnp.dtype # this could be a real type? 
Array = Any -PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]] - PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]] PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]] LaxPadding = Union[str, Sequence[Tuple[int, int]]] @@ -47,7 +45,7 @@ class ConvDimensionNumbers(NamedTuple): def _flip_axes(x, axes): """Flip ndarray 'x' along each axis specified in axes tuple.""" for axis in axes: - x = np.flip(x, axis) + x = jnp.flip(x, axis) return x @@ -183,6 +181,7 @@ def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: ) +# Copied from a contributor's PR on jax: https://github.com/yang-song/jax/commit/883a7c9e812e0f7af8dffa0eb54d017fd2200f10 def _compute_adjusted_padding( input_size: int, output_size: int, From 350558278b8ef7dbd8ed2551dc46ee8496e8ab11 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 10 Aug 2022 15:28:48 +0200 Subject: [PATCH 16/22] Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- src/transformers/models/dpt/gradient_convolution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/dpt/gradient_convolution.py b/src/transformers/models/dpt/gradient_convolution.py index bd0f3b562421..7228bdb950b5 100644 --- a/src/transformers/models/dpt/gradient_convolution.py +++ b/src/transformers/models/dpt/gradient_convolution.py @@ -332,7 +332,7 @@ def gradient_based_conv_transpose( class ConvTransposeGradient(nn.Module): - """Convolution Module wrapping lax.conv_transpose. + """Convolution Module wrapping lax.conv_transpose. Calculates transposed convolutions via gradient (transpose) of a forward convolution. Attributes: features: number of convolution filters. @@ -375,7 +375,7 @@ class ConvTransposeGradient(nn.Module): def __call__(self, inputs: Array) -> Array: """Applies a transposed convolution to the inputs. - Behaviour mirrors of `jax.lax.conv_transpose`. + Behaviour mirrors of `jax.lax.conv_transpose`, computing transposed convolutions via the gradient (transpose) of a forward convolutions. Args: inputs: input data with dimensions (batch, spatial_dims..., features). 
From 425508e403579a281a909b3b923ff4fece3ba543 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 10 Aug 2022 15:32:47 +0200 Subject: [PATCH 17/22] Update src/transformers/models/dpt/gradient_convolution.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- src/transformers/models/dpt/gradient_convolution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/dpt/gradient_convolution.py b/src/transformers/models/dpt/gradient_convolution.py index 7228bdb950b5..b416a36b4d29 100644 --- a/src/transformers/models/dpt/gradient_convolution.py +++ b/src/transformers/models/dpt/gradient_convolution.py @@ -365,7 +365,7 @@ class ConvTransposeGradient(nn.Module): padding: PaddingLike = "SAME" kernel_dilation: Optional[Sequence[int]] = None use_bias: bool = True - dtype: Dtype = jnp.float32 + dtype: jnp.dtype = jnp.float32 param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init From 24aeb4d13a735ad11241736233b26a80ba203d35 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 10 Aug 2022 16:22:57 +0200 Subject: [PATCH 18/22] add few comments --- src/transformers/models/dpt/gradient_convolution.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/dpt/gradient_convolution.py b/src/transformers/models/dpt/gradient_convolution.py index b416a36b4d29..d912251d46f6 100644 --- a/src/transformers/models/dpt/gradient_convolution.py +++ b/src/transformers/models/dpt/gradient_convolution.py @@ -181,7 +181,7 @@ def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: ) -# Copied from a contributor's PR on jax: https://github.com/yang-song/jax/commit/883a7c9e812e0f7af8dffa0eb54d017fd2200f10 +# Copied from a contributor's PR on jax: https://github.com/yang-song/jax/commit/883a7c9e812e0f7af8dffa0eb54d017fd2200f10 def _compute_adjusted_padding( input_size: int, output_size: int, @@ -332,7 +332,8 @@ def gradient_based_conv_transpose( class ConvTransposeGradient(nn.Module): - """Convolution Module wrapping lax.conv_transpose. Calculates transposed convolutions via gradient (transpose) of a forward convolution. + """Convolution Module wrapping lax.conv_transpose. Calculates transposed convolutions via gradient (transpose) of a +forward convolution. Attributes: features: number of convolution filters. @@ -374,8 +375,12 @@ class ConvTransposeGradient(nn.Module): @compact def __call__(self, inputs: Array) -> Array: """Applies a transposed convolution to the inputs. + Here we mimic the implementation of the ConvTranspose module: + https://flax.readthedocs.io/en/latest/_modules/flax/linen/linear.html#ConvTranspose and define the `__call__` + method with the @compact decorator - Behaviour mirrors of `jax.lax.conv_transpose`, computing transposed convolutions via the gradient (transpose) of a forward convolutions. + Behaviour mirrors of `jax.lax.conv_transpose`, computing transposed convolutions via the gradient (transpose) + of a forward convolutions. Args: inputs: input data with dimensions (batch, spatial_dims..., features). 
From 479d0e83142ec8260f33b92a561f0c45bf432d02 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 10 Aug 2022 16:26:27 +0200 Subject: [PATCH 19/22] Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- .../models/dpt/modeling_flax_dpt.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index ad577ebbbf5a..fe62cde47ee7 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -205,7 +205,7 @@ def __call__(self, hidden_states, deterministic: bool = True, output_attentions: outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs - +`# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTOutput with ViT->DPT` class FlaxDPTViTOutput(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -506,7 +506,7 @@ class FlaxDPTFeatureFusionLayer(nn.Module): align_corners: bool = True def setup(self): - self.projection = nn.Conv(self.config.fusion_hidden_size, kernel_size=(1, 1)) # , bias=True) + self.projection = nn.Conv(self.config.fusion_hidden_size, kernel_size=(1, 1), use_bias=True) self.residual_layer1 = FlaxDPTPreActResidualLayer(self.config) self.residual_layer2 = FlaxDPTPreActResidualLayer(self.config) @@ -516,16 +516,16 @@ def __call__(self, hidden_state, residual=None): if residual is not None: if hidden_state.shape != residual.shape: size = hidden_state.shape - residual = self.upsample(residual, size) + residual = self.upsample(residual, align_corners=True) hidden_state = hidden_state + self.residual_layer1(residual) hidden_state = self.residual_layer2(hidden_state) - hidden_state = self.upsample(hidden_state) + hidden_state = self.upsample(hidden_state, align_corners=self.align_corners) hidden_state = self.projection(hidden_state) return hidden_state - +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput class FlaxDPTViTLayer(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -561,7 +561,7 @@ def __call__(self, hidden_states, deterministic: bool = True, output_attentions: outputs += (attention_outputs[1],) return outputs - +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTLayerCollection with ViTConfig->DPTConfig, ViTLayer->DPTViTLayer class FlaxDPTViTLayerCollection(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -604,7 +604,7 @@ def __call__( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions ) - +# Copied from transformers.models.vit.modeling_vit.FlaxViTPooler with ViT->DPT class FlaxDPTViTPooler(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -621,7 +621,7 @@ def __call__(self, hidden_states): cls_hidden_state = self.dense(cls_hidden_state) return nn.tanh(cls_hidden_state) - +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTEncoder with ViTConfig -> DPTConfig, ViTLayer->DPTViTLayer class FlaxDPTViTEncoder(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -698,7 +698,7 @@ def __init__( if input_shape is None: 
input_shape = (1, config.image_size, config.image_size, 3) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - + # Copied from transformers.models.FlaxViTPreTrainedModel.init_weights def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors pixel_values = jnp.zeros(input_shape, dtype=self.dtype) @@ -719,6 +719,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz return random_params @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + # Copied from transformers.models.FlaxViTPreTrainedModel.__call__ def __call__( self, pixel_values, @@ -863,8 +864,8 @@ class FlaxDPTUpsample(nn.Module): def setup(self): pass - def __call__(self, x, output_size=None): - if output_size is None: + def __call__(self, x, align_corners=True): + if not align_corners: output_size = x.shape output_size = (output_size[0], output_size[1] * self.scale, output_size[2] * self.scale, output_size[3]) return jax.image.resize(x, output_size, method=self.method) From 791ea245ed3470463d3a9e6fee2c49d384a4a71a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 10 Aug 2022 16:59:39 +0200 Subject: [PATCH 20/22] refactor a bit: - removed `copied_from` on non module objects - check why `FlaxDPTViTLayerCollection` is not copied from `FlaxViTLayerCollection` --- .../models/dpt/gradient_convolution.py | 50 +++++++++---------- .../models/dpt/modeling_flax_dpt.py | 24 +++++---- 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/dpt/gradient_convolution.py b/src/transformers/models/dpt/gradient_convolution.py index d912251d46f6..1c949af46ed9 100644 --- a/src/transformers/models/dpt/gradient_convolution.py +++ b/src/transformers/models/dpt/gradient_convolution.py @@ -333,31 +333,31 @@ def gradient_based_conv_transpose( class ConvTransposeGradient(nn.Module): """Convolution Module wrapping lax.conv_transpose. Calculates transposed convolutions via gradient (transpose) of a -forward convolution. - - Attributes: - features: number of convolution filters. - kernel_size: shape of the convolutional kernel. For 1D convolution, - the kernel size can be passed as an integer. For all other cases, it must - be a sequence of integers. - strides: a sequence of `n` integers, representing the inter-window strides. - padding: either the string `'SAME'`, the string `'VALID'`, the string - `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, - high)` integer pairs that give the padding to apply before and after each - spatial dimension. A single int is interpeted as applying the same padding - in all dims and passign a single int in a sequence causes the same padding - to be used on both sides. - kernel_dilation: `None`, or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel. Convolution with kernel dilation is also known as 'atrous - convolution'. - use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - kernel_init: initializer for the convolutional kernel. - bias_init: initializer for the bias. + forward convolution. + + Attributes: + features: number of convolution filters. 
+ kernel_size: shape of the convolutional kernel. For 1D convolution, + the kernel size can be passed as an integer. For all other cases, it must + be a sequence of integers. + strides: a sequence of `n` integers, representing the inter-window strides. + padding: either the string `'SAME'`, the string `'VALID'`, the string + `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. A single int is interpeted as applying the same padding + in all dims and passign a single int in a sequence causes the same padding + to be used on both sides. + kernel_dilation: `None`, or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel. Convolution with kernel dilation is also known as 'atrous + convolution'. + use_bias: whether to add a bias to the output (default: True). + dtype: the dtype of the computation (default: float32). + param_dtype: the dtype passed to parameter initializers (default: float32). + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the convolutional kernel. + bias_init: initializer for the bias. """ features: int diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index fe62cde47ee7..0f07025d6142 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -205,7 +205,8 @@ def __call__(self, hidden_states, deterministic: bool = True, output_attentions: outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs -`# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTOutput with ViT->DPT` + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTOutput with ViT->DPT class FlaxDPTViTOutput(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -503,7 +504,6 @@ def __call__(self, hidden_states, residual=None, i=0): class FlaxDPTFeatureFusionLayer(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 - align_corners: bool = True def setup(self): self.projection = nn.Conv(self.config.fusion_hidden_size, kernel_size=(1, 1), use_bias=True) @@ -516,15 +516,16 @@ def __call__(self, hidden_state, residual=None): if residual is not None: if hidden_state.shape != residual.shape: size = hidden_state.shape - residual = self.upsample(residual, align_corners=True) + residual = self.upsample(residual, output_size=size) hidden_state = hidden_state + self.residual_layer1(residual) hidden_state = self.residual_layer2(hidden_state) - hidden_state = self.upsample(hidden_state, align_corners=self.align_corners) + hidden_state = self.upsample(hidden_state) hidden_state = self.projection(hidden_state) return hidden_state + # Copied from transformers.models.vit.modeling_flax_vit.FlaxViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput class FlaxDPTViTLayer(nn.Module): config: DPTConfig @@ -561,7 +562,7 @@ def __call__(self, hidden_states, deterministic: bool = True, output_attentions: outputs += (attention_outputs[1],) return outputs -# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTLayerCollection with ViTConfig->DPTConfig, ViTLayer->DPTViTLayer + class FlaxDPTViTLayerCollection(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the 
computation @@ -604,7 +605,8 @@ def __call__( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions ) -# Copied from transformers.models.vit.modeling_vit.FlaxViTPooler with ViT->DPT + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTPooler with ViT->DPT class FlaxDPTViTPooler(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -621,7 +623,8 @@ def __call__(self, hidden_states): cls_hidden_state = self.dense(cls_hidden_state) return nn.tanh(cls_hidden_state) -# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTEncoder with ViTConfig -> DPTConfig, ViTLayer->DPTViTLayer + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTEncoder with ViTConfig->DPTConfig, FlaxViTLayerCollection->FlaxDPTViTLayerCollection class FlaxDPTViTEncoder(nn.Module): config: DPTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -698,7 +701,7 @@ def __init__( if input_shape is None: input_shape = (1, config.image_size, config.image_size, 3) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - # Copied from transformers.models.FlaxViTPreTrainedModel.init_weights + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors pixel_values = jnp.zeros(input_shape, dtype=self.dtype) @@ -719,7 +722,6 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz return random_params @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - # Copied from transformers.models.FlaxViTPreTrainedModel.__call__ def __call__( self, pixel_values, @@ -864,8 +866,8 @@ class FlaxDPTUpsample(nn.Module): def setup(self): pass - def __call__(self, x, align_corners=True): - if not align_corners: + def __call__(self, x, output_size=None): + if output_size is None: output_size = x.shape output_size = (output_size[0], output_size[1] * self.scale, output_size[2] * self.scale, output_size[3]) return jax.image.resize(x, output_size, method=self.method) From a40f21dfe73079f26c5afccccad1e1cf14d15ac8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 10 Aug 2022 18:49:52 +0200 Subject: [PATCH 21/22] add comments on key naming strategy --- src/transformers/models/dpt/modeling_flax_dpt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index 0f07025d6142..9eac7ce6e1bc 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -348,6 +348,9 @@ class FlaxDPTReadoutProjectSequentialCollectionLayer(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + # Here we set the name of the layer to be explicitly `0` because in the + # Pytorch modeling script we use nn.Sequential layer that sets the + # key of this dense layer to 0 by default. 
self.dense = nn.Dense(self.config.hidden_size, name="0") self.act = ACT2FN[self.config.hidden_act] From f945eef63b733bd247b0d5d7c83a9fa28f06467c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 10 Aug 2022 20:04:35 +0200 Subject: [PATCH 22/22] few modifications - added correct link for `CopiedFrom` - Added explicit argument for transposed conv on model def --- src/transformers/models/dpt/gradient_convolution.py | 7 +++++-- src/transformers/models/dpt/modeling_flax_dpt.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/dpt/gradient_convolution.py b/src/transformers/models/dpt/gradient_convolution.py index 1c949af46ed9..39c535f0eb48 100644 --- a/src/transformers/models/dpt/gradient_convolution.py +++ b/src/transformers/models/dpt/gradient_convolution.py @@ -182,6 +182,7 @@ def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: # Copied from a contributor's PR on jax: https://github.com/yang-song/jax/commit/883a7c9e812e0f7af8dffa0eb54d017fd2200f10 +# originally copied from: https://github.com/deepmind/dm-haiku/blob/c5dddba7c99850e4ce5a6d8bb17fc3f06c236359/haiku/_src/conv.py#L430 def _compute_adjusted_padding( input_size: int, output_size: int, @@ -205,7 +206,7 @@ def _compute_adjusted_padding( padding_before = 0 elif padding == "SAME": expected_input_size = (output_size + stride - 1) // stride - if input_size != expected_input_size: + if expected_input_size != input_size: raise ValueError( "The expected input size with the current set of input " f"parameters is {expected_input_size} which doesn't " @@ -291,8 +292,10 @@ def gradient_based_conv_transpose( raise ValueError("No 4+ dimensional dimension_number defaults.") dn = conv_dimension_numbers_(lhs.shape, rhs.shape, dimension_numbers) # dn = dimension_numbers + # k_shape = jnp.take(jnp.array(rhs.shape), jnp.array(dn.rhs_spec)) k_shape = np.take(rhs.shape, dn.rhs_spec) k_sdims = k_shape[2:] # type: ignore[index] + # i_shape = jnp.take(jnp.array(lhs.shape), jnp.array(dn.lhs_spec)) i_shape = np.take(lhs.shape, dn.lhs_spec) i_sdims = i_shape[2:] # type: ignore[index] @@ -363,7 +366,7 @@ class ConvTransposeGradient(nn.Module): features: int kernel_size: Union[int, Tuple[int, ...]] strides: Optional[Tuple[int, ...]] = None - padding: PaddingLike = "SAME" + padding: PaddingLike = (0, 0) kernel_dilation: Optional[Sequence[int]] = None use_bias: bool = True dtype: jnp.dtype = jnp.float32 diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py index 9eac7ce6e1bc..ea7edad078d1 100644 --- a/src/transformers/models/dpt/modeling_flax_dpt.py +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -310,6 +310,7 @@ def setup(self): kernel_size=(self.factor, self.factor), strides=(self.factor, self.factor), use_bias=True, + padding="SAME", ) elif self.factor < 1: # so should downsample