diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 5c0d51d8b7af..7b3fef4618f5 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -225,7 +225,7 @@ Flax), PyTorch, and/or TensorFlow. | DETR | ❌ | ❌ | ✅ | ❌ | ❌ | | DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ | | DPR | ✅ | ✅ | ✅ | ✅ | ❌ | -| DPT | ❌ | ❌ | ✅ | ❌ | ❌ | +| DPT | ❌ | ❌ | ✅ | ❌ | ✅ | | ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ | | Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ | | FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/dpt.mdx b/docs/source/en/model_doc/dpt.mdx index cdf009c6c8a0..1f33507eae71 100644 --- a/docs/source/en/model_doc/dpt.mdx +++ b/docs/source/en/model_doc/dpt.mdx @@ -54,4 +54,19 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi ## DPTForSemanticSegmentation [[autodoc]] DPTForSemanticSegmentation - - forward \ No newline at end of file + - forward + +## FlaxDPTForSemanticSegmentation + +[[autodoc]] FlaxDPTForSemanticSegmentation + - __call__ + +## FlaxDPTForDepthEstimation + +[[autodoc]] FlaxDPTForDepthEstimation + - __call__ + +## FlaxDPTModel + +[[autodoc]] FlaxDPTModel + - __call__ \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index be2be2727f01..992cb50c4ecf 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2745,6 +2745,14 @@ "FlaxDistilBertPreTrainedModel", ] ) + _import_structure["models.dpt"].extend( + [ + "FlaxDPTModel", + "FlaxDPTPreTrainedModel", + "FlaxDPTForSemanticSegmentation", + "FlaxDPTForDepthEstimation", + ] + ) _import_structure["models.electra"].extend( [ "FlaxElectraForCausalLM", @@ -5101,6 +5109,12 @@ FlaxDistilBertModel, FlaxDistilBertPreTrainedModel, ) + from .models.dpt import ( + FlaxDPTForDepthEstimation, + FlaxDPTForSemanticSegmentation, + FlaxDPTModel, + FlaxDPTPreTrainedModel, + ) from .models.electra import ( FlaxElectraForCausalLM, FlaxElectraForMaskedLM, diff --git a/src/transformers/modeling_flax_outputs.py b/src/transformers/modeling_flax_outputs.py index 4f6cc5a901f8..cb93bd319536 100644 --- a/src/transformers/modeling_flax_outputs.py +++ b/src/transformers/modeling_flax_outputs.py @@ -640,3 +640,72 @@ class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput): encoder_last_hidden_state: Optional[jnp.ndarray] = None encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None encoder_attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxDepthEstimatorOutput(ModelOutput): + """ + Base class for outputs of depth estimation models. + + Args: + loss (`jnp.ndarray` of shape `(1,)`, *optional*, returned when `labels` is provided): + Depth Estimation loss. + predicted_depth (`jnp.ndarray` of shape `(batch_size, height, width)`): + Predicted depth for each pixel. + + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. 
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: jnp.ndarray = None
+    predicted_depth: jnp.ndarray = None
+    hidden_states: Optional[Tuple[jnp.ndarray]] = None
+    attentions: Optional[Tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxSemanticSegmenterOutput(ModelOutput):
+    """
+    Base class for outputs of semantic segmentation models.
+
+    Args:
+        loss (`jnp.ndarray` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Semantic segmentation loss (cross-entropy loss).
+        logits (`jnp.ndarray` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
+            Semantic segmentation raw logits for each pixel.
+
+            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+            to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the
+            original image size as post-processing. You should always check your logits shape and resize as needed.
+
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
+            for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+ """ + + loss: jnp.ndarray = None + logits: jnp.ndarray = None + predicted_depth: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 98c5d6fb5a10..6d3aa807d7ad 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -37,6 +37,7 @@ ("blenderbot-small", "FlaxBlenderbotSmallModel"), ("clip", "FlaxCLIPModel"), ("distilbert", "FlaxDistilBertModel"), + ("dpt", "FlaxDPTModel"), ("electra", "FlaxElectraModel"), ("gpt2", "FlaxGPT2Model"), ("gpt_neo", "FlaxGPTNeoModel"), @@ -211,6 +212,17 @@ ] ) +FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING = OrderedDict( + [ + ("dpt", "FlaxDPTForDepthEstimation"), + ] +) + +FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = OrderedDict( + [ + ("dpt", "FlaxDPTForSemanticSegmentation"), + ] +) FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES) FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES) @@ -241,6 +253,12 @@ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES ) +FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING +) +FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING +) class FlaxAutoModel(_BaseAutoModelClass): @@ -344,3 +362,21 @@ class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass): FlaxAutoModelForSpeechSeq2Seq = auto_class_update( FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling" ) + + +class FlaxAutoModelForDepthEstimation(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING + + +FlaxAutoModelForDepthEstimation = auto_class_update( + FlaxAutoModelForDepthEstimation, head_doc="depth estimation modeling" +) + + +class FlaxAutoModelForSemanticSegmentation(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING + + +FlaxAutoModelForSemanticSegmentation = auto_class_update( + FlaxAutoModelForSemanticSegmentation, head_doc="semantic segmentation modeling" +) diff --git a/src/transformers/models/dpt/__init__.py b/src/transformers/models/dpt/__init__.py index 1df82ab62824..2de43595c57d 100644 --- a/src/transformers/models/dpt/__init__.py +++ b/src/transformers/models/dpt/__init__.py @@ -17,7 +17,13 @@ # limitations under the License. 
from typing import TYPE_CHECKING -from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available, is_vision_available +from ...file_utils import ( + _LazyModule, + is_flax_available, + is_tokenizers_available, + is_torch_available, + is_vision_available, +) from ...utils import OptionalDependencyNotAvailable @@ -45,6 +51,19 @@ "DPTPreTrainedModel", ] +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_dpt"] = [ + "FlaxDPTForSemanticSegmentation", + "FlaxDPTForDepthEstimation", + "FlaxDPTModel", + "FlaxDPTPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_dpt import DPT_PRETRAINED_CONFIG_ARCHIVE_MAP, DPTConfig @@ -71,6 +90,18 @@ DPTPreTrainedModel, ) + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_dpt import ( + FlaxDPTForDepthEstimation, + FlaxDPTForSemanticSegmentation, + FlaxDPTModel, + FlaxDPTPreTrainedModel, + ) else: import sys diff --git a/src/transformers/models/dpt/configuration_dpt.py b/src/transformers/models/dpt/configuration_dpt.py index a255b0596b4d..63bd764c5b2f 100644 --- a/src/transformers/models/dpt/configuration_dpt.py +++ b/src/transformers/models/dpt/configuration_dpt.py @@ -137,6 +137,7 @@ def __init__( auxiliary_loss_weight=0.4, semantic_loss_ignore_index=255, semantic_classifier_dropout=0.1, + align_corners=True, **kwargs ): super().__init__(**kwargs) @@ -168,3 +169,4 @@ def __init__( self.auxiliary_loss_weight = auxiliary_loss_weight self.semantic_loss_ignore_index = semantic_loss_ignore_index self.semantic_classifier_dropout = semantic_classifier_dropout + self.align_corners = align_corners diff --git a/src/transformers/models/dpt/gradient_convolution.py b/src/transformers/models/dpt/gradient_convolution.py new file mode 100644 index 000000000000..39c535f0eb48 --- /dev/null +++ b/src/transformers/models/dpt/gradient_convolution.py @@ -0,0 +1,458 @@ +from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union + +import numpy as np + +import flax.linen as nn +import jax.numpy as jnp +from flax.linen.initializers import lecun_normal, zeros +from flax.linen.module import compact +from jax import lax +from jax.lax import conv_general_dilated + + +default_kernel_init = lecun_normal() + +PRNGKey = Any +Shape = Tuple[int, ...] +Dtype = jnp.dtype # this could be a real type? +Array = Any +PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]] +PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]] +LaxPadding = Union[str, Sequence[Tuple[int, int]]] + + +class ConvDimensionNumbers(NamedTuple): + """ + Describes batch, spatial, and feature dimensions of a convolution. + + Args: + lhs_spec: a tuple of nonnegative integer dimension numbers containing + `(batch dimension, feature dimension, spatial dimensions...)`. + rhs_spec: a tuple of nonnegative integer dimension numbers containing + `(out feature dimension, in feature dimension, spatial dimensions...)`. + out_spec: a tuple of nonnegative integer dimension numbers containing + `(batch dimension, feature dimension, spatial dimensions...)`. 
+ """ + + lhs_spec: Sequence[int] + rhs_spec: Sequence[int] + out_spec: Sequence[int] + + +ConvGeneralDilatedDimensionNumbers = Union[None, ConvDimensionNumbers, Tuple[str, str, str]] + + +def _flip_axes(x, axes): + """Flip ndarray 'x' along each axis specified in axes tuple.""" + for axis in axes: + x = jnp.flip(x, axis) + return x + + +def conv_general_permutations(dimension_numbers): + """Utility for convolution dimension permutations relative to Conv HLO.""" + lhs_spec, rhs_spec, out_spec = dimension_numbers + lhs_char, rhs_char, out_char = charpairs = ("N", "C"), ("O", "I"), ("N", "C") + for i, (a, b) in enumerate(charpairs): + if not dimension_numbers[i].count(a) == dimension_numbers[i].count(b) == 1: + msg = "convolution dimension_numbers[{}] must contain the characters '{}' and '{}' exactly once, got {}." + raise TypeError(msg.format(i, a, b, dimension_numbers[i])) + if len(dimension_numbers[i]) != len(set(dimension_numbers[i])): + msg = "convolution dimension_numbers[{}] cannot have duplicate characters, got {}." + raise TypeError(msg.format(i, dimension_numbers[i])) + if not (set(lhs_spec) - set(lhs_char) == set(rhs_spec) - set(rhs_char) == set(out_spec) - set(out_char)): + msg = "convolution dimension_numbers elements must each have the same set of spatial characters, got {}." + raise TypeError(msg.format(dimension_numbers)) + + def getperm(spec, charpair): + spatial = (i for i, c in enumerate(spec) if c not in charpair) + if spec is not rhs_spec: + spatial = sorted(spatial, key=lambda i: rhs_spec.index(spec[i])) + return (spec.index(charpair[0]), spec.index(charpair[1])) + tuple(spatial) + + lhs_perm, rhs_perm, out_perm = map(getperm, dimension_numbers, charpairs) + return lhs_perm, rhs_perm, out_perm + + +def conv_dimension_numbers_(lhs_shape, rhs_shape, dimension_numbers) -> ConvDimensionNumbers: + """Converts convolution `dimension_numbers` to a `ConvDimensionNumbers`. + + Args: + lhs_shape: tuple of nonnegative integers, shape of the convolution input. + rhs_shape: tuple of nonnegative integers, shape of the convolution kernel. + dimension_numbers: None or a tuple/list of strings or a ConvDimensionNumbers + object following the convolution dimension number specification format in xla_client.py. + + Returns: + A `ConvDimensionNumbers` object that represents `dimension_numbers` in the canonical form used by lax functions. + """ + if isinstance(dimension_numbers, ConvDimensionNumbers): + return dimension_numbers + if len(lhs_shape) != len(rhs_shape): + msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}." + raise TypeError(msg.format(len(lhs_shape), len(rhs_shape))) + + if dimension_numbers is None: + iota = tuple(range(len(lhs_shape))) + return ConvDimensionNumbers(iota, iota, iota) + elif isinstance(dimension_numbers, (list, tuple)): + if len(dimension_numbers) != 3: + msg = "convolution dimension_numbers list/tuple must be length 3, got {}." + raise TypeError(msg.format(len(dimension_numbers))) + if not all(isinstance(elt, str) for elt in dimension_numbers): + msg = "convolution dimension_numbers elements must be strings, got {}." + raise TypeError(msg.format(tuple(map(type, dimension_numbers)))) + msg = ( + "convolution dimension_numbers[{}] must have len equal to the ndim " + "of lhs and rhs, got {} for lhs and rhs shapes {} and {}." 
+ ) + for i, elt in enumerate(dimension_numbers): + if len(elt) != len(lhs_shape): + raise TypeError(msg.format(i, len(elt), lhs_shape, rhs_shape)) + + lhs_spec, rhs_spec, out_spec = conv_general_permutations(dimension_numbers) + return ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) + else: + msg = "convolution dimension_numbers must be tuple/list or None, got {}." + raise TypeError(msg.format(type(dimension_numbers))) + + +def _deconv_output_length(input_length, filter_size, padding, output_padding=None, stride=0, dilation=1): + """Determines the output length of a transposed convolution given the input length. + Arguments: + Function modified from Keras. + input_length: Integer. filter_size: Integer. padding: one of `"SAME"`, `"VALID"`, or a 2-integer tuple. + output_padding: Integer, amount of padding along the output dimension. Can + be set to `None` in which case the output length is inferred. + stride: Integer. dilation: Integer. + Returns: + The output length (integer). + """ + if input_length is None: + return None + + # Get the dilated kernel size + filter_size = filter_size + (filter_size - 1) * (dilation - 1) + + # Infer length if output padding is None, else compute the exact length + if output_padding is None: + if padding == "VALID": + length = input_length * stride + max(filter_size - stride, 0) + elif padding == "SAME": + length = input_length * stride + else: + length = (input_length - 1) * stride + filter_size - padding[0] - padding[1] + + else: + if padding == "SAME": + pad = filter_size // 2 + total_pad = pad * 2 + elif padding == "VALID": + total_pad = 0 + else: + total_pad = padding[0] + padding[1] + + length = (input_length - 1) * stride + filter_size - total_pad + output_padding + + return length + + +def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: + if isinstance(padding, str): + return padding + if isinstance(padding, int): + return [(padding, padding)] * rank + + if isinstance(padding, Sequence) and len(padding) == rank: + new_pad = [] + for p in padding: + if isinstance(p, int): + new_pad.append((p, p)) + elif isinstance(p, tuple) and len(p) == 2: + new_pad.append(p) + else: + break + if len(new_pad) == rank: + return new_pad + raise ValueError( + f"Invalid padding format: {padding}, should be str, int," + f" or a sequence of len {rank} where each element is an" + " int or pair of ints." + ) + + +# Copied from a contributor's PR on jax: https://github.com/yang-song/jax/commit/883a7c9e812e0f7af8dffa0eb54d017fd2200f10 +# originally copied from: https://github.com/deepmind/dm-haiku/blob/c5dddba7c99850e4ce5a6d8bb17fc3f06c236359/haiku/_src/conv.py#L430 +def _compute_adjusted_padding( + input_size: int, + output_size: int, + kernel_size: int, + stride: int, + padding: Union[str, Tuple[int, int]], + dilation: int = 1, +) -> Tuple[int, int]: + """Computes adjusted padding for desired ConvTranspose `output_size`. + Ported from DeepMind Haiku. + """ + kernel_size = (kernel_size - 1) * dilation + 1 + if padding == "VALID": + expected_input_size = (output_size - kernel_size + stride) // stride + if input_size != expected_input_size: + raise ValueError( + "The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}." 
+ ) + padding_before = 0 + elif padding == "SAME": + expected_input_size = (output_size + stride - 1) // stride + if expected_input_size != input_size: + raise ValueError( + "The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}." + ) + padding_needed = max(0, (input_size - 1) * stride + kernel_size - output_size) + padding_before = padding_needed // 2 + else: + padding_before = padding[0] # type: ignore[assignment] + + expanded_input_size = (input_size - 1) * stride + 1 + padded_out_size = output_size + kernel_size - 1 + pad_before = kernel_size - 1 - padding_before + pad_after = padded_out_size - expanded_input_size - pad_before + return (pad_before, pad_after) + + +def gradient_based_conv_transpose( + lhs: Array, + rhs: Array, + strides: Sequence[int], + padding: Union[str, Sequence[Tuple[int, int]]], + output_padding: Optional[Sequence[int]] = None, + output_shape: Optional[Sequence[int]] = None, + dilation: Optional[Sequence[int]] = None, + dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, + transpose_kernel: bool = True, + precision: PrecisionLike = None, +) -> Array: + """Convenience wrapper for calculating the N-d transposed convolution. + Args: + Much like *conv_transpose*, this function calculates transposed convolutions via fractionally strided convolution + rather than calculating the gradient (transpose) of a forward convolution. However, the latter is more common among: + deep learning frameworks, such as TensorFlow, PyTorch, and Keras. This function provides the same set of APIs to help: + reproduce results in these frameworks. + lhs: a rank *n+2* dimensional input array. rhs: a rank *n+2* dimensional array of kernel weights. strides: + sequence of *n* integers, amounts to strides of the corresponding forward convolution. padding: *"SAME"*, + *"VALID"*, or a sequence of *n* integer 2-tuples that controls + the before-and-after padding for each *n* spatial dimension of the corresponding forward convolution. + output_padding: A sequence of integers specifying the amount of padding along + each spacial dimension of the output tensor, used to disambiguate the output shape of transposed convolutions + when the stride is larger than 1. (see a detailed description at + 1https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html) The amount of output padding along a + given dimension must be lower than the stride along that same dimension. If set to *None* (default), the output + shape is inferred. If both *output_padding* and *output_shape* are specified, they have to be mutually + compatible. + output_shape: Output shape of the spatial dimensions of a transpose + convolution. Can be *None* or an iterable of *n* integers. If a *None* value is given (default), the shape is + automatically calculated. Similar to *output_padding*, *output_shape* is also for disambiguating the output + shape when stride > 1 (see also https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose) If both + *output_padding* and *output_shape* are specified, they have to be mutually compatible. + dilation: *None*, or a sequence of *n* integers, giving the + dilation factor to apply in each spatial dimension of *rhs*. Dilated convolution is also known as atrous + convolution. + dimension_numbers: tuple of dimension descriptors as in + lax.conv_general_dilated. Defaults to tensorflow convention. 
+ transpose_kernel: if *True* flips spatial axes and swaps the input/output + channel axes of the kernel. This makes the output of this function identical to the gradient-derived functions + like keras.layers.Conv2DTranspose and torch.nn.ConvTranspose2d applied to the same kernel. Although for typical + use in neural nets this is unnecessary and makes input/output channel specification confusing, you need to set + this to *True* in order to match the behavior in many deep learning frameworks, such as TensorFlow, Keras, and + PyTorch. + precision: Optional. Either `None`, which means the default precision for + the backend, a `lax.Precision` enum value (`Precision.DEFAULT`, `Precision.HIGH` or `Precision.HIGHEST`) or a + tuple of two `lax.Precision` enums indicating precision of ``lhs``` and `rhs`. + Returns: + Transposed N-d convolution. + """ + assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2 + ndims = len(lhs.shape) + one = (1,) * (ndims - 2) + # Set dimensional layout defaults if not specified. + if dimension_numbers is None: + if ndims == 2: + dimension_numbers = ("NC", "IO", "NC") + elif ndims == 3: + dimension_numbers = ("NHC", "HIO", "NHC") + elif ndims == 4: + dimension_numbers = ("NHWC", "HWIO", "NHWC") + elif ndims == 5: + dimension_numbers = ("NHWDC", "HWDIO", "NHWDC") + else: + raise ValueError("No 4+ dimensional dimension_number defaults.") + dn = conv_dimension_numbers_(lhs.shape, rhs.shape, dimension_numbers) + # dn = dimension_numbers + # k_shape = jnp.take(jnp.array(rhs.shape), jnp.array(dn.rhs_spec)) + k_shape = np.take(rhs.shape, dn.rhs_spec) + k_sdims = k_shape[2:] # type: ignore[index] + # i_shape = jnp.take(jnp.array(lhs.shape), jnp.array(dn.lhs_spec)) + i_shape = np.take(lhs.shape, dn.lhs_spec) + i_sdims = i_shape[2:] # type: ignore[index] + + # Calculate correct output shape given padding and strides. + if dilation is None: + dilation = (1,) * (rhs.ndim - 2) + + if output_padding is None: + output_padding = [None] * (rhs.ndim - 2) # type: ignore[list-item] + + if isinstance(padding, str): + if padding in {"SAME", "VALID"}: + padding = [padding] * (rhs.ndim - 2) # type: ignore[list-item] + else: + raise ValueError(f"`padding` must be 'VALID' or 'SAME'. Passed: {padding}.") + + inferred_output_shape = tuple( + map(_deconv_output_length, i_sdims, k_sdims, padding, output_padding, strides, dilation) + ) + if output_shape is None: + output_shape = inferred_output_shape # type: ignore[assignment] + else: + if not output_shape == inferred_output_shape: + raise ValueError( + "`output_padding` and `output_shape` are not compatible." + f"Inferred output shape from `output_padding`: {inferred_output_shape}, " + f"but got `output_shape` {output_shape}" + ) + + pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, k_sdims, strides, padding, dilation)) + + if transpose_kernel: + # flip spatial dims and swap input / output channel axes + rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) + rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) + return conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dimension_numbers, precision=precision) + + +class ConvTransposeGradient(nn.Module): + """Convolution Module wrapping lax.conv_transpose. Calculates transposed convolutions via gradient (transpose) of a + forward convolution. + + Attributes: + features: number of convolution filters. + kernel_size: shape of the convolutional kernel. For 1D convolution, + the kernel size can be passed as an integer. 
For all other cases, it must + be a sequence of integers. + strides: a sequence of `n` integers, representing the inter-window strides. + padding: either the string `'SAME'`, the string `'VALID'`, the string + `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. A single int is interpeted as applying the same padding + in all dims and passign a single int in a sequence causes the same padding + to be used on both sides. + kernel_dilation: `None`, or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel. Convolution with kernel dilation is also known as 'atrous + convolution'. + use_bias: whether to add a bias to the output (default: True). + dtype: the dtype of the computation (default: float32). + param_dtype: the dtype passed to parameter initializers (default: float32). + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the convolutional kernel. + bias_init: initializer for the bias. + """ + + features: int + kernel_size: Union[int, Tuple[int, ...]] + strides: Optional[Tuple[int, ...]] = None + padding: PaddingLike = (0, 0) + kernel_dilation: Optional[Sequence[int]] = None + use_bias: bool = True + dtype: jnp.dtype = jnp.float32 + param_dtype: Dtype = jnp.float32 + precision: PrecisionLike = None + kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + + @compact + def __call__(self, inputs: Array) -> Array: + """Applies a transposed convolution to the inputs. + Here we mimic the implementation of the ConvTranspose module: + https://flax.readthedocs.io/en/latest/_modules/flax/linen/linear.html#ConvTranspose and define the `__call__` + method with the @compact decorator + + Behaviour mirrors of `jax.lax.conv_transpose`, computing transposed convolutions via the gradient (transpose) + of a forward convolutions. + + Args: + inputs: input data with dimensions (batch, spatial_dims..., features). + This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: + this is different from the input convention used by `lax.conv_general_dilated`, which puts the spatial + dimensions last. + + Returns: + The convolved data. + """ + inputs = jnp.asarray(inputs, self.dtype) + + kernel_size: Tuple[int, ...] + if isinstance(self.kernel_size, int): + kernel_size = (self.kernel_size,) + else: + kernel_size = self.kernel_size + + is_single_input = False + if inputs.ndim == len(kernel_size) + 1: + is_single_input = True + inputs = jnp.expand_dims(inputs, axis=0) + + strides: Tuple[int, ...] 
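+        # fall back to a stride of 1 along every spatial dimension when no strides are configured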
+ strides = self.strides or (1,) * (inputs.ndim - 2) + + in_features = inputs.shape[-1] + kernel_shape = kernel_size + (in_features, self.features) + kernel = self.param("kernel", self.kernel_init, kernel_shape, self.param_dtype) + kernel = jnp.asarray(kernel, self.dtype) + + padding_lax = canonicalize_padding(self.padding, len(kernel_size)) + if padding_lax == "CIRCULAR": + padding_lax = "VALID" + + y = gradient_based_conv_transpose( + inputs, kernel, strides, padding_lax, dilation=self.kernel_dilation, precision=self.precision + ) + + if self.padding == "CIRCULAR": + # For circular padding, we need to identify the size of the final output + # ("period") along each spatial dimension, pad each dimension to an + # integer number of periods, and wrap the array periodically around each + # dimension. Padding should be done in such a way that the start of the + # original input data inside the padded array is located at integer + # number of periods - otherwise the result would be circularly shifted. + + # Compute period along each spatial dimension - it's input size scaled + # by the stride. + scaled_x_dims = [x_dim * stride for x_dim, stride in zip(inputs.shape[1:-1], strides)] + # Compute difference between the current size of y and the final output + # size, and complement this difference to 2 * period - that gives how + # much we need to pad. + size_diffs = [-(y_dim - x_dim) % (2 * x_dim) for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims)] + # Divide the padding equaly between left and right. The choice to put + # "+1" on the left (and not on the right) represents a convention for + # aligning even-sized kernels. + total_pad = [((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs] + y = np.pad(y, [(0, 0)] + total_pad + [(0, 0)]) + # Wrap the result periodically around each spatial dimension, + # one by one. + for i in range(1, y.ndim - 1): + y = y.reshape(y.shape[:i] + (-1, scaled_x_dims[i - 1]) + y.shape[i + 1 :]) + y = y.sum(axis=i) + + if is_single_input: + y = jnp.squeeze(y, axis=0) + if self.use_bias: + bias = self.param("bias", self.bias_init, (self.features,), self.param_dtype) + bias = jnp.asarray(bias, self.dtype) + y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) + return y diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 7dfa244805ff..927818e3a84a 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -582,10 +582,10 @@ class DPTFeatureFusionLayer(nn.Module): The align_corner setting for bilinear upsample. 
""" - def __init__(self, config, align_corners=True): + def __init__(self, config): super().__init__() - self.align_corners = align_corners + self.align_corners = config.align_corners self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True) @@ -832,7 +832,7 @@ def __init__(self, config): features = config.fusion_hidden_size self.head = nn.Sequential( nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), - nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=self.config.align_corners), nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), ACT2FN["relu"], nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), @@ -979,7 +979,7 @@ def __init__(self, config): ACT2FN["relu"], nn.Dropout(config.semantic_classifier_dropout), nn.Conv2d(features, config.num_labels, kernel_size=1), - nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=self.config.align_corners), ) def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor: diff --git a/src/transformers/models/dpt/modeling_flax_dpt.py b/src/transformers/models/dpt/modeling_flax_dpt.py new file mode 100644 index 000000000000..ea7edad078d1 --- /dev/null +++ b/src/transformers/models/dpt/modeling_flax_dpt.py @@ -0,0 +1,1243 @@ +# coding=utf-8 +# Copyright 2022 Intel Labs, OpenMMLab and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Flax DPT (Dense Prediction Transformers) model. + +TThis implementation is heavily inspired by OpenMMLab's implementation, found here: +https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/dpt_head.py. + +""" +import math +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPooling, + FlaxDepthEstimatorOutput, + FlaxSemanticSegmenterOutput, +) +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward +from .configuration_dpt import DPTConfig +from .gradient_convolution import ConvTransposeGradient + + +DPT_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from Flax models) This + model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. 
Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + Parameters: + config ([`DPTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision + inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`. + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] + and [`~FlaxPreTrainedModel.to_bf16`]. +""" + +DPT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`DPTFeatureExtractor`]. See + [`DPTFeatureExtractor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTPatchEmbeddings with ViT->DPT +class FlaxDPTPatchEmbeddings(nn.Module): + + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + image_size = self.config.image_size + patch_size = self.config.patch_size + num_patches = (image_size // patch_size) * (image_size // patch_size) + self.num_patches = num_patches + self.num_channels = self.config.num_channels + self.projection = nn.Conv( + self.config.hidden_size, + kernel_size=(patch_size, patch_size), + strides=(patch_size, patch_size), + padding="VALID", + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + def __call__(self, pixel_values): + num_channels = pixel_values.shape[-1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
+ ) + embeddings = self.projection(pixel_values) + batch_size, _, _, channels = embeddings.shape + return jnp.reshape(embeddings, (batch_size, -1, channels)) + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTEmbeddings with ViT->DPT +class FlaxDPTEmbeddings(nn.Module): + """Construct the CLS token, position and patch embeddings.""" + + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size)) + self.patch_embeddings = FlaxDPTPatchEmbeddings(self.config, dtype=self.dtype) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = self.param( + "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size) + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, pixel_values, deterministic=True): + batch_size = pixel_values.shape[0] + + embeddings = self.patch_embeddings(pixel_values) + + cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size)) + embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1) + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings, deterministic=deterministic) + return embeddings + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfAttention with ViT->DPT +class FlaxDPTSelfAttention(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:" + " {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + use_bias=self.config.qkv_bias, + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + use_bias=self.config.qkv_bias, + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + use_bias=self.config.qkv_bias, + ) + + def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + 
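+        # `outputs` holds the attention output with the head dimension merged back, i.e. shape
+        # (batch_size, seq_len, hidden_size), plus the attention weights when output_attentions=True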
return outputs + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTOutput with ViT->DPT +class FlaxDPTViTOutput(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = hidden_states + attention_output + return hidden_states + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfOutput with ViT->DPT +class FlaxDPTSelfOutput(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTAttention with ViT->DPT +class FlaxDPTViTAttention(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.attention = FlaxDPTSelfAttention(self.config, dtype=self.dtype) + self.output = FlaxDPTSelfOutput(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False): + attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTIntermediate with ViT->DPT +class FlaxDPTViTIntermediate(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# DPT reassemble & Fusion +class FlaxDPTReassembleLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + factor: int = 1 + channels: int = None + + def setup(self): + # projection + self.projection = nn.Conv( + self.channels, + kernel_size=(1, 1), + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + use_bias=True, + ) + + # up/down sampling depending on factor + if self.factor > 1: + self.resize = ConvTransposeGradient( + self.channels, + kernel_size=(self.factor, self.factor), + strides=(self.factor, self.factor), + use_bias=True, + padding="SAME", + ) + elif self.factor < 1: + # so should downsample + self.resize = nn.Conv( + self.channels, + kernel_size=(3, 3), + strides=(int(1 / 
self.factor), int(1 / self.factor)), + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + padding=(1, 1), + ) + + def __call__(self, hidden_state): + hidden_state = self.projection(hidden_state) + if self.factor != 1: + hidden_state = self.resize(hidden_state) + return hidden_state + + +class FlaxDPTReassembleLayerCollection(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = [ + FlaxDPTReassembleLayer(self.config, factor=factor, channels=self.config.neck_hidden_sizes[i], name=str(i)) + for i, factor in zip(range(len(self.config.neck_hidden_sizes)), self.config.reassemble_factors) + ] + + def __call__(self, x, i): + return self.layers[i](x) + + +class FlaxDPTReadoutProjectSequentialCollectionLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # Here we set the name of the layer to be explicitly `0` because in the + # Pytorch modeling script we use nn.Sequential layer that sets the + # key of this dense layer to 0 by default. + self.dense = nn.Dense(self.config.hidden_size, name="0") + self.act = ACT2FN[self.config.hidden_act] + + def __call__(self, x): + x = self.dense(x) + return self.act(x) + + +class FlaxDPTReadoutProjectCollectionLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = [ + FlaxDPTReadoutProjectSequentialCollectionLayer(self.config, self.dtype, name=str(i)) + for i in range(len(self.config.neck_hidden_sizes)) + ] + + def __call__(self, hidden_states, i): + return self.layers[i](hidden_states) + + +class FlaxDPTReassembleStage(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + + self.layers = FlaxDPTReassembleLayerCollection(self.config, self.dtype) + + if self.config.readout_type == "project": + self.readout_projects = FlaxDPTReadoutProjectCollectionLayer(self.config, self.dtype) + + def __call__(self, hidden_states): + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`): + List of hidden states from the backbone. 
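+
+        Returns:
+            A list with one feature map per backbone stage, reshaped to an image-like `(batch_size, height, width,
+            num_channels)` layout, projected to the matching `config.neck_hidden_sizes` channel count and resized
+            according to the corresponding `config.reassemble_factors` entry.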
+ """ + out = [] + + for i, hidden_state in enumerate(hidden_states): + # reshape to (B, C, H, W) + hidden_state, cls_token = hidden_state[:, 1:], hidden_state[:, 0] + batch_size, sequence_length, num_channels = hidden_state.shape + size = int(math.sqrt(sequence_length)) + hidden_state = jnp.reshape(hidden_state, (batch_size, size, size, num_channels)) + + feature_shape = hidden_state.shape + if self.config.readout_type == "project": + # reshape to (B, H*W, C) + hidden_state = jnp.reshape(hidden_state, (batch_size, size * size, num_channels)) + readout = jnp.expand_dims(cls_token, axis=1) + readout = jnp.repeat(readout, size * size, axis=1) + # concatenate the readout token to the hidden states and project + hidden_state = self.readout_projects(jnp.concatenate((hidden_state, readout), axis=-1), i) + # reshape back to (B, C, H, W) + hidden_state = jnp.reshape(hidden_state, feature_shape) + elif self.config.readout_type == "add": + hidden_state = jnp.reshape(hidden_state, (batch_size, size * size, num_channels)) + jnp.expand_dims( + cls_token, axis=-1 + ) + hidden_state = jnp.reshape(hidden_state, feature_shape) + # hidden_state = self.layers[i](hidden_state) + hidden_state = self.layers(hidden_state, i) + out.append(hidden_state) + + return out + + +class FlaxDPTFeatureFusionStage(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + super().__init__() + self.layers = FlaxDPTFeatureFusionLayerCollection(self.config, self.dtype) + + def __call__(self, hidden_states): + # reversing the hidden_states, we start from the last + hidden_states = hidden_states[::-1] + + fused_hidden_states = [] + # first layer only uses the last hidden_state + fused_hidden_state = self.layers(hidden_states[0], residual=None, i=0) + fused_hidden_states.append(fused_hidden_state) + # looping from the last layer to the second + for i, hidden_state in enumerate(hidden_states[1:]): + fused_hidden_state = self.layers(fused_hidden_state, residual=hidden_state, i=i + 1) + fused_hidden_states.append(fused_hidden_state) + return fused_hidden_states + + +class FlaxDPTPreActResidualLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + + self.use_batch_norm = self.config.use_batch_norm_in_fusion_residual + self.activation1 = ACT2FN["relu"] + self.convolution1 = nn.Conv( + self.config.fusion_hidden_size, + kernel_size=(3, 3), + strides=(1, 1), + padding="SAME", + use_bias=not self.use_batch_norm, + ) + + self.activation2 = ACT2FN["relu"] + self.convolution2 = nn.Conv( + self.config.fusion_hidden_size, + kernel_size=(3, 3), + strides=(1, 1), + padding="SAME", + use_bias=not self.use_batch_norm, + ) + + if self.use_batch_norm: + self.batch_norm1 = nn.BatchNorm(use_running_average=True) + self.batch_norm2 = nn.BatchNorm(use_running_average=True) + + def __call__(self, hidden_state): + residual = hidden_state + hidden_state = self.activation1(hidden_state) + + hidden_state = self.convolution1(hidden_state) + + if self.use_batch_norm: + hidden_state = self.batch_norm1(hidden_state) + + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution2(hidden_state) + + if self.use_batch_norm: + hidden_state = self.batch_norm2(hidden_state) + + return hidden_state + residual + + +class FlaxDPTFeatureFusionLayerCollection(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = [ + FlaxDPTFeatureFusionLayer(self.config, name=str(i)) for i in range(len(self.config.neck_hidden_sizes)) + ] + + def 
__call__(self, hidden_states, residual=None, i=0): + return self.layers[i](hidden_states, residual) + + +class FlaxDPTFeatureFusionLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.projection = nn.Conv(self.config.fusion_hidden_size, kernel_size=(1, 1), use_bias=True) + + self.residual_layer1 = FlaxDPTPreActResidualLayer(self.config) + self.residual_layer2 = FlaxDPTPreActResidualLayer(self.config) + self.upsample = FlaxDPTUpsample() + + def __call__(self, hidden_state, residual=None): + if residual is not None: + if hidden_state.shape != residual.shape: + size = hidden_state.shape + residual = self.upsample(residual, output_size=size) + hidden_state = hidden_state + self.residual_layer1(residual) + + hidden_state = self.residual_layer2(hidden_state) + hidden_state = self.upsample(hidden_state) + hidden_state = self.projection(hidden_state) + + return hidden_state + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput +class FlaxDPTViTLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxDPTViTAttention(self.config, dtype=self.dtype) + self.intermediate = FlaxDPTViTIntermediate(self.config, dtype=self.dtype) + self.output = FlaxDPTViTOutput(self.config, dtype=self.dtype) + self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): + attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention + deterministic=deterministic, + output_attentions=output_attentions, + ) + + attention_output = attention_outputs[0] + + # first residual connection + attention_output = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(attention_output) + + hidden_states = self.intermediate(layer_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + return outputs + + +class FlaxDPTViTLayerCollection(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxDPTViTLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + 
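+        # all_hidden_states, when requested, contains the embedding output followed by the output of every
+        # encoder layer; all_attentions holds one attention map per layer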
return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTPooler with ViT->DPT +class FlaxDPTViTPooler(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + cls_hidden_state = hidden_states[:, 0] + cls_hidden_state = self.dense(cls_hidden_state) + return nn.tanh(cls_hidden_state) + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTEncoder with ViTConfig->DPTConfig, FlaxViTLayerCollection->FlaxDPTViTLayerCollection +class FlaxDPTViTEncoder(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxDPTViTLayerCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +DPT_START_DOCSTRING = r""" + This model is a Flax [jax.nn.Module](https://jax.readthedocs.io/en/latest/jax.nn.html?highlight=nn.Module) + subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matter related to general + usage and behavior. + + Parameters: + config ([`DPTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DPT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`jax.numpy.array` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`ViTFeatureExtractor`]. See + [`ViTFeatureExtractor.__call__`] for details. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +class FlaxDPTPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
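+
+    `pixel_values` are expected in the PyTorch-style `(batch_size, num_channels, height, width)` layout; they are
+    transposed to the channels-last layout used by Flax convolutions inside `__call__`.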
+ """ + + config_class = DPTConfig + base_model_prefix = "dpt" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: DPTConfig, + input_shape=None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, 3) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + pixel_values, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[jnp.ndarray] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_attentions, + output_hidden_states, + return_dict, + labels, + rngs=rngs, + mutable=["batch_stats"], + )[0] + + +class FlaxDPTModule(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + + def setup(self): + self.embeddings = FlaxDPTEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxDPTViTEncoder(self.config, dtype=self.dtype) + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.pooler = FlaxDPTViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None + + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + labels: Optional[jnp.ndarray] = None, + ): + + hidden_states = self.embeddings(pixel_values, deterministic=deterministic) + + outputs = self.encoder( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + hidden_states = self.layernorm(hidden_states) + pooled = self.pooler(hidden_states) if self.add_pooling_layer else None + + if not return_dict: + # if 
pooled is None, don't return it
+            if pooled is None:
+                return (hidden_states,) + outputs[1:]
+            return (hidden_states, pooled) + outputs[1:]
+
+        return FlaxBaseModelOutputWithPooling(
+            last_hidden_state=hidden_states,
+            pooler_output=pooled,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    The bare DPT Model transformer outputting raw hidden-states without any specific head on top.
+    """,
+    DPT_START_DOCSTRING,
+)
+class FlaxDPTModel(FlaxDPTPreTrainedModel):
+    module_class = FlaxDPTModule
+
+
+class FlaxDPTConvCollection(nn.Module):
+    config: DPTConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.convs = [
+            nn.Conv(self.config.fusion_hidden_size, kernel_size=(3, 3), padding=1, use_bias=False, name=str(i))
+            for i in range(len(self.config.neck_hidden_sizes))
+        ]
+
+    def __call__(self, features):
+        return [self.convs[i](feature) for i, feature in enumerate(features)]
+
+
+class FlaxDPTNeck(nn.Module):
+    config: DPTConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        # postprocessing
+        self.reassemble_stage = FlaxDPTReassembleStage(self.config, self.dtype)
+        self.convs = FlaxDPTConvCollection(self.config, self.dtype)
+        # fusion
+        self.fusion_stage = FlaxDPTFeatureFusionStage(self.config)
+
+    def __call__(self, hidden_states):
+        if not isinstance(hidden_states, list):
+            raise ValueError("hidden_states should be a list of tensors")
+
+        if len(hidden_states) != len(self.config.neck_hidden_sizes):
+            raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
+
+        # postprocess hidden states
+        features = self.reassemble_stage(hidden_states)
+
+        features = self.convs(features)
+
+        # fusion blocks
+        output = self.fusion_stage(features)
+
+        return output
+
+
+class FlaxDPTUpsample(nn.Module):
+    scale: int = 2
+    method: str = "bilinear"
+
+    def setup(self):
+        pass
+
+    def __call__(self, x, output_size=None):
+        if output_size is None:
+            output_size = x.shape
+            output_size = (output_size[0], output_size[1] * self.scale, output_size[2] * self.scale, output_size[3])
+        return jax.image.resize(x, output_size, method=self.method)
+
+
+class FlaxDPTDepthEstimationHeadCollectionLayer(nn.Module):
+    config: DPTConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.conv1 = nn.Conv(
+            self.config.fusion_hidden_size // 2, kernel_size=(3, 3), strides=(1, 1), padding="SAME", name="0"
+        )
+
+        self.upsample = FlaxDPTUpsample(scale=2, method="bilinear")
+
+        self.conv2 = nn.Conv(32, kernel_size=(3, 3), strides=(1, 1), padding="SAME", name="2")
+
+        self.act = ACT2FN["relu"]
+
+        self.conv3 = nn.Conv(1, kernel_size=(1, 1), strides=(1, 1), padding="VALID", name="4")
+
+    def __call__(self, hidden_state):
+        x = self.conv1(hidden_state)
+        x = self.upsample(x)
+        x = self.conv2(x)
+        x = self.act(x)
+        x = self.conv3(x)
+        return self.act(x)
+
+
+class FlaxDPTDepthEstimationHead(nn.Module):
+    config: DPTConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.head = FlaxDPTDepthEstimationHeadCollectionLayer(self.config, self.dtype)
+
+    def __call__(self, hidden_states):
+        # use last features
+        hidden_states = hidden_states[self.config.head_in_index]
+
+        predicted_depth = self.head(hidden_states)
+
+        predicted_depth = jnp.squeeze(predicted_depth, -1)
+
+        return predicted_depth
+
+
+class FlaxDPTForDepthEstimationModule(nn.Module):
+    config: DPTConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+
+        self.dpt = FlaxDPTModule(self.config, add_pooling_layer=False)
+
+        # Neck
+        self.neck = FlaxDPTNeck(self.config)
+
+        # Depth estimation head
+        self.head = FlaxDPTDepthEstimationHead(self.config)
+
+    @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
+    def __call__(
+        self,
+        pixel_values,
+        deterministic: bool = True,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        labels=None,
+    ):
+        r"""
+        labels (`jnp.ndarray` of shape `(batch_size, height, width)`, *optional*):
+            Ground truth depth estimation maps for computing the loss.
+
+        Returns:
+
+        Examples:
+        ```python
+        >>> from transformers import DPTFeatureExtractor, FlaxDPTForDepthEstimation
+        >>> import jax
+        >>> import numpy as np
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
+        >>> model = FlaxDPTForDepthEstimation.from_pretrained("Intel/dpt-large", from_pt=True)
+
+        >>> # prepare image for the model
+        >>> inputs = feature_extractor(images=image, return_tensors="np")
+
+        >>> outputs = model(**inputs)
+        >>> predicted_depth = outputs.predicted_depth
+
+        >>> # interpolate to the original image size (PIL reports (width, height))
+        >>> prediction = jax.image.resize(
+        ...     predicted_depth[..., None], (1, image.size[1], image.size[0], 1), method="bilinear"
+        ... )
+
+        >>> # visualize the prediction
+        >>> output = np.array(prediction.squeeze())
+        >>> formatted = (output * 255 / np.max(output)).astype("uint8")
+        >>> depth = Image.fromarray(formatted)
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        outputs = self.dpt(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=True,  # we need the intermediate hidden states
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs.hidden_states if return_dict else outputs
+
+        # only keep certain features based on config.backbone_out_indices
+        # note that the hidden_states also include the initial embeddings
+        if return_dict:
+            hidden_states = [
+                feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
+            ]
+        else:
+            hidden_states = [
+                feature for idx, feature in enumerate(hidden_states[1][1:]) if idx in self.config.backbone_out_indices
+            ]
+
+        hidden_states = self.neck(hidden_states)
+
+        predicted_depth = self.head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            raise NotImplementedError("Training is not implemented yet")
+
+        if not return_dict:
+            if output_hidden_states:
+                output = (predicted_depth,) + outputs[1:]
+            else:
+                output = (predicted_depth,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return FlaxDepthEstimatorOutput(
+            loss=loss,
+            predicted_depth=predicted_depth,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
+ """, + DPT_START_DOCSTRING, +) +class FlaxDPTForDepthEstimation(FlaxDPTPreTrainedModel): + module_class = FlaxDPTForDepthEstimationModule + + +class FlaxDPTSemanticSegmentationHeadCollectionLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv1 = nn.Conv(self.config.fusion_hidden_size, kernel_size=(3, 3), padding=1, name="0", use_bias=False) + self.bn = nn.BatchNorm(use_running_average=True, name="1") + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.act = ACT2FN["relu"] + self.conv2 = nn.Conv(self.config.num_labels, kernel_size=(1, 1), name="4") + self.upsample = FlaxDPTUpsample(scale=2, method="bilinear") + + def __call__(self, hidden_states, deterministic=True): + x = self.conv1(hidden_states) + x = self.bn(x) + x = self.act(x) + x = self.dropout(x, deterministic=deterministic) + x = self.conv2(x) + x = self.upsample(x) + return x + + +class FlaxDPTSemanticSegmentationHead(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + # TODO: Change this to make sure BatchNorm + Dropout work / Put them outside a Sequential Module + def setup(self): + self.head = FlaxDPTSemanticSegmentationHeadCollectionLayer(self.config, self.dtype) + + # @nn.compact + def __call__(self, hidden_states, deterministic=True): + # use last features + hidden_states = hidden_states[self.config.head_in_index] + + logits = self.head(hidden_states, deterministic=deterministic) + return jnp.transpose(logits, (0, 3, 1, 2)) + + +class FlaxDPTAuxiliaryHeadCollectionLayer(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv1 = nn.Conv(self.config.fusion_hidden_size, kernel_size=(3, 3), padding=1, name="0", use_bias=False) + self.bn = nn.BatchNorm(use_running_average=True, name="1") + self.act = ACT2FN["relu"] + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.conv2 = nn.Conv(self.config.num_labels, kernel_size=(1, 1), name="4") + + def __call__(self, hidden_states, deterministic=True): + x = self.conv1(hidden_states) + x = self.bn(x) + x = self.act(x) + x = self.dropout(x, deterministic=deterministic) + x = self.conv2(x) + return x + + +class FlaxDPTAuxiliaryHead(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.head = FlaxDPTAuxiliaryHeadCollectionLayer(self.config, self.dtype) + + def __call__(self, hidden_states, deterministic=True): + logits = self.head(hidden_states, deterministic=deterministic) + + return jnp.transpose(logits, (0, 3, 1, 2)) + + +class FlaxDPTForSemanticSegmentationModule(nn.Module): + config: DPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + + self.dpt = FlaxDPTModule(self.config, add_pooling_layer=False) + + # Neck + self.neck = FlaxDPTNeck(self.config) + + # Segmentation head(s) + self.head = FlaxDPTSemanticSegmentationHead(self.config) + self.auxiliary_head = FlaxDPTAuxiliaryHead(self.config) if self.config.use_auxiliary_head else None + + self.upsample = FlaxDPTUpsample(scale=2, method="bilinear") + + @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING) + def __call__( + self, + pixel_values=None, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+        Returns:
+
+        Examples:
+        ```python
+        >>> from transformers import DPTFeatureExtractor, FlaxDPTForSemanticSegmentation
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large-ade")
+        >>> model = FlaxDPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade", from_pt=True)
+
+        >>> inputs = feature_extractor(images=image, return_tensors="np")
+
+        >>> outputs = model(**inputs)
+        >>> logits = outputs.logits
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        outputs = self.dpt(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=True,  # we need the intermediate hidden states
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs.hidden_states if return_dict else outputs
+
+        # only keep certain features based on config.backbone_out_indices
+        # note that the hidden_states also include the initial embeddings
+        if return_dict:
+            hidden_states = [
+                feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
+            ]
+        else:
+            hidden_states = [
+                feature for idx, feature in enumerate(hidden_states[1][1:]) if idx in self.config.backbone_out_indices
+            ]
+        hidden_states = self.neck(hidden_states)
+
+        logits = self.head(hidden_states)
+
+        auxiliary_logits = None
+        if self.auxiliary_head is not None:
+            auxiliary_logits = self.auxiliary_head(hidden_states[-1])
+
+        loss = None
+        if labels is not None:
+            if self.config.num_labels == 1:
+                raise ValueError("The number of labels should be greater than one")
+            else:
+                # upsample logits to the images' original size
+                output_shape = (logits.shape[0], labels.shape[1], labels.shape[2], logits.shape[3])
+                upsampled_logits = self.upsample(logits, output_shape)
+
+                # compute weighted loss
+                # Copied from: https://flax.readthedocs.io/en/latest/notebooks/annotated_mnist.html
+                labels_onehot = jax.nn.one_hot(labels, num_classes=self.config.num_labels)
+                main_loss = optax.softmax_cross_entropy(logits=upsampled_logits, labels=labels_onehot).mean()
+                loss = main_loss
+
+                if auxiliary_logits is not None:
+                    upsampled_auxiliary_logits = self.upsample(auxiliary_logits, output_shape)
+                    auxiliary_loss = optax.softmax_cross_entropy(
+                        logits=upsampled_auxiliary_logits, labels=labels_onehot
+                    ).mean()
+                    loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
+
+        if not return_dict:
+            if output_hidden_states:
+                output = (logits,) + outputs[1:]
+            else:
+                output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return FlaxSemanticSegmenterOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    DPT Model with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
+ """, + DPT_START_DOCSTRING, +) +class FlaxDPTForSemanticSegmentation(FlaxDPTPreTrainedModel): + module_class = FlaxDPTForSemanticSegmentationModule diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 953808dab8ad..a45c9731362a 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -592,6 +592,34 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxDPTForDepthEstimation(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxDPTForSemanticSegmentation(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxDPTModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxDPTPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxElectraForCausalLM(metaclass=DummyObject): _backends = ["flax"] diff --git a/tests/models/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py index 3266ea78a71a..5c53ee9e2a43 100644 --- a/tests/models/dpt/test_modeling_dpt.py +++ b/tests/models/dpt/test_modeling_dpt.py @@ -19,9 +19,9 @@ import unittest from transformers import DPTConfig -from transformers.file_utils import is_torch_available, is_vision_available +from transformers.file_utils import is_flax_available, is_torch_available, is_vision_available from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -40,6 +40,18 @@ from transformers import DPTFeatureExtractor +if is_torch_available() and is_flax_available(): + import tempfile + + import numpy as np + + import jax.numpy as jnp + import transformers + from transformers.modeling_flax_pytorch_utils import ( + convert_pytorch_state_dict_to_flax, + load_flax_weights_in_pytorch_model, + ) + class DPTModelTester: def __init__( @@ -239,6 +251,159 @@ def test_training_gradient_checkpointing(self): loss = model(**inputs).loss loss.backward() + @is_pt_flax_cross_test + def test_equivalence_pt_to_flax(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.align_corners = False + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + # no flax model exists for this class + return + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load PyTorch class + pt_model = model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
+ pt_model.config.use_cache = False + + # load Flax class + fx_model = fx_model_class(config, dtype=jnp.float32) + + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + # send pytorch inputs to the correct device + pt_inputs = { + k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() + } + + # convert inputs to Flax + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) + + @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.align_corners = False + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + # no flax model exists for this class + return + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load PyTorch class + pt_model = model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
+ pt_model.config.use_cache = False + + # load Flax class + fx_model = fx_model_class(config, dtype=jnp.float32) + + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + # send pytorch inputs to the correct device + pt_inputs = { + k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() + } + + # convert inputs to Flax + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) + + # send pytorch model to the correct device + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) + @slow def test_model_from_pretrained(self): for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/dpt/test_modeling_flax_dpt.py b/tests/models/dpt/test_modeling_flax_dpt.py new file mode 100644 index 000000000000..7ad3c453fc70 --- /dev/null +++ b/tests/models/dpt/test_modeling_flax_dpt.py @@ -0,0 +1,330 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the Flax DPT model. 
""" + +import inspect +import unittest + +import numpy as np + +from transformers import DPTConfig, is_flax_available +from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow + +from ...test_configuration_common import ConfigTester +from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor + + +if is_flax_available(): + + import jax + from transformers.models.dpt.modeling_flax_dpt import ( + FlaxDPTForDepthEstimation, + FlaxDPTForSemanticSegmentation, + FlaxDPTModel, + ) + + +class FlaxDPTModelTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=2, + image_size=32, + patch_size=16, + num_channels=3, + is_training=True, + use_labels=False, + hidden_size=32, + num_hidden_layers=4, + backbone_out_indices=[0, 1, 2, 3], + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + initializer_range=0.02, + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.backbone_out_indices = backbone_out_indices + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.num_labels = num_labels + self.scope = scope + # sequence length of DPT = num_patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + + if self.use_labels: + labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels) + + config = DPTConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + backbone_out_indices=self.backbone_out_indices, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + is_decoder=False, + initializer_range=self.initializer_range, + align_corners=False, + ) + + return config, pixel_values, labels + + def create_and_check_model(self, config, pixel_values, labels): + + model = FlaxDPTModel(config=config) + result = model(pixel_values) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + image_size = (self.image_size, self.image_size) + patch_size = (self.patch_size, self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + pixel_values, + labels, + ) = config_and_inputs + inputs_dict = {"pixel_values": pixel_values, "labels": labels} + return config, inputs_dict + + +@require_flax +class 
FlaxDPTModelTest(FlaxModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + (FlaxDPTModel, FlaxDPTForSemanticSegmentation, FlaxDPTForDepthEstimation) if is_flax_available() else () + ) + + def setUp(self) -> None: + self.model_tester = FlaxDPTModelTester(self) + self.config_tester = ConfigTester(self, config_class=DPTConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + # We neeed to override this test because ViT's forward signature is different than text models. + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.__call__) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + # We need to override this test because ViT expects pixel_values instead of input_ids + def test_jit_compilation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def model_jitted(pixel_values, **kwargs): + return model(pixel_values=pixel_values, **kwargs) + + with self.subTest("JIT Enabled"): + jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + model = model_class_name.from_pretrained("Intel/dpt-large", from_pt=True) + outputs = model(np.ones((1, 3, 384, 384))) + self.assertIsNotNone(outputs) + + @is_pt_flax_cross_test + def test_equivalence_pt_to_flax(self): + import tempfile + + import torch + + import jax.numpy as jnp + import transformers + from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax + from transformers.testing_utils import torch_device + + # It might be better to put this inside the for loop below (because we modify the config there). + # But logically, it is fine. + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + prepared_inputs_dict.pop("labels") + pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class + pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + pt_model = pt_model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
+ pt_model.config.use_cache = False + fx_model = model_class(config, dtype=jnp.float32) + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) + + @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + import tempfile + + import torch + + import jax.numpy as jnp + import transformers + from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model + from transformers.testing_utils import torch_device + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + prepared_inputs_dict.pop("labels") + pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class + pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + pt_model = pt_model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
+ pt_model.config.use_cache = False + fx_model = model_class(config, dtype=jnp.float32) + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) + + # send pytorch model to the correct device + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) + + +# TODO: add tests for segmentation and depth estimation (logits) +# @require_vision +# @slow +# def test_model_from_pretrained_example(self): +# model = FlaxDPTForDepthEstimation.from_pretrained("Intel/dpt-large", from_pt=True) +# image = prepare_img() diff --git a/utils/check_repo.py b/utils/check_repo.py index d2271e87ebf1..993863dcabc1 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -154,6 +154,8 @@ "FlaxCLIPTextModel", "FlaxCLIPVisionModel", "FlaxWav2Vec2ForCTC", + "FlaxDPTForSemanticSegmentation", + "FlaxDPTForDepthEstimation", "DetrForSegmentation", "DPRReader", "FlaubertForQuestionAnswering",
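
For reference, the only resizing primitive the Flax port relies on is `jax.image.resize` over NHWC arrays, which is what `FlaxDPTUpsample` wraps. The following standalone sketch (not part of the diff; the feature-map shape is made up for illustration) shows the two ways it is used: doubling the spatial dimensions, and resizing to an explicit target shape as the segmentation loss does with the label resolution.

```python
import jax
import jax.numpy as jnp

# hypothetical NHWC feature map: (batch, height, width, channels)
x = jnp.ones((1, 24, 24, 256))

# scale by 2, as FlaxDPTUpsample does when no output_size is passed
doubled = jax.image.resize(x, (1, 48, 48, 256), method="bilinear")

# resize to an explicit target shape, as done when upsampling logits to the label size
to_target = jax.image.resize(x, (1, 384, 384, 256), method="bilinear")

print(doubled.shape, to_target.shape)  # (1, 48, 48, 256) (1, 384, 384, 256)
```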