2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -225,7 +225,7 @@ Flax), PyTorch, and/or TensorFlow.
| DETR | ❌ | ❌ | ✅ | ❌ | ❌ |
| DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ |
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
| DPT | ❌ | ❌ | ✅ | ❌ | ❌ |
| DPT | ❌ | ❌ | ✅ | ❌ | ✅ |
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
| Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
17 changes: 16 additions & 1 deletion docs/source/en/model_doc/dpt.mdx
@@ -54,4 +54,19 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
## DPTForSemanticSegmentation

[[autodoc]] DPTForSemanticSegmentation
- forward
- forward

## FlaxDPTForSemanticSegmentation

[[autodoc]] FlaxDPTForSemanticSegmentation
- __call__

## FlaxDPTForDepthEstimation

[[autodoc]] FlaxDPTForDepthEstimation
- __call__

## FlaxDPTModel

[[autodoc]] FlaxDPTModel
- __call__
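
To make the new doc sections concrete, here is a minimal usage sketch for the Flax segmentation head. The checkpoint name and the `from_pt=True` conversion are assumptions (no Flax weights ship with this PR); `DPTFeatureExtractor` is the existing preprocessing class.

```python
# Hedged sketch: assumes the PyTorch checkpoint "Intel/dpt-large-ade" can be
# converted on the fly with from_pt=True; not verified against this PR.
import requests
from PIL import Image

from transformers import DPTFeatureExtractor, FlaxDPTForSemanticSegmentation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large-ade")
model = FlaxDPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade", from_pt=True)

inputs = feature_extractor(images=image, return_tensors="np")  # Flax models consume NumPy arrays
outputs = model(**inputs)
logits = outputs.logits  # (batch_size, num_labels, logits_height, logits_width)
```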
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
@@ -2745,6 +2745,14 @@
"FlaxDistilBertPreTrainedModel",
]
)
_import_structure["models.dpt"].extend(
[
"FlaxDPTModel",
"FlaxDPTPreTrainedModel",
"FlaxDPTForSemanticSegmentation",
"FlaxDPTForDepthEstimation",
]
)
_import_structure["models.electra"].extend(
[
"FlaxElectraForCausalLM",
@@ -5101,6 +5109,12 @@
FlaxDistilBertModel,
FlaxDistilBertPreTrainedModel,
)
from .models.dpt import (
FlaxDPTForDepthEstimation,
FlaxDPTForSemanticSegmentation,
FlaxDPTModel,
FlaxDPTPreTrainedModel,
)
from .models.electra import (
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
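With the `_import_structure` entries and the matching `TYPE_CHECKING` imports above, the new classes resolve from the top-level namespace through the usual lazy-module machinery. A quick smoke test (requires `flax` installed):

```python
# All four names registered above should be importable from the package root.
from transformers import (
    FlaxDPTForDepthEstimation,
    FlaxDPTForSemanticSegmentation,
    FlaxDPTModel,
    FlaxDPTPreTrainedModel,
)

print(FlaxDPTModel.__module__)  # transformers.models.dpt.modeling_flax_dpt
```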
69 changes: 69 additions & 0 deletions src/transformers/modeling_flax_outputs.py
@@ -640,3 +640,72 @@ class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
encoder_last_hidden_state: Optional[jnp.ndarray] = None
encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
encoder_attentions: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxDepthEstimatorOutput(ModelOutput):
"""
Base class for outputs of depth estimation models.

Args:
loss (`jnp.ndarray` of shape `(1,)`, *optional*, returned when `labels` is provided):
Depth estimation loss.
predicted_depth (`jnp.ndarray` of shape `(batch_size, height, width)`):
Predicted depth for each pixel.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each layer) of shape `(batch_size, num_channels, height, width)`.

Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, patch_size,
sequence_length)`.

Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""

loss: Optional[jnp.ndarray] = None
predicted_depth: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxSemanticSegmenterOutput(ModelOutput):
"""
Base class for outputs of semantic segmentation models.

Args:
loss (`jnp.ndarray` of shape `(1,)`, *optional*, returned when `labels` is provided):
Semantic segmentation loss (cross-entropy loss).
logits (`jnp.ndarray` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
Semantic Segmentation raw logits for each pixel.

<Tip warning={true}>

The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
to avoid doing two interpolations and losing quality when a user needs to resize the logits to the
original image size as post-processing. You should always check your logits shape and resize as needed.

</Tip>

hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each layer) of shape `(batch_size, num_channels, height, width)`.

Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, patch_size,
sequence_length)`.

Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""

loss: Optional[jnp.ndarray] = None
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
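
The Tip in `FlaxSemanticSegmenterOutput` tells users to resize the logits themselves. Below is a sketch of that post-processing with `jax.image.resize`; choosing bilinear interpolation here is an assumption, not something this diff prescribes.

```python
import jax
import jax.numpy as jnp

def upsample_logits(logits: jnp.ndarray, height: int, width: int) -> jnp.ndarray:
    """Resize (batch_size, num_labels, h, w) logits to the original image size."""
    batch_size, num_labels = logits.shape[:2]
    return jax.image.resize(logits, (batch_size, num_labels, height, width), method="bilinear")

# Per-pixel class map at the input resolution, e.g. for a 480x640 image:
# segmentation = jnp.argmax(upsample_logits(outputs.logits, 480, 640), axis=1)
```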
36 changes: 36 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -37,6 +37,7 @@
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("clip", "FlaxCLIPModel"),
("distilbert", "FlaxDistilBertModel"),
("dpt", "FlaxDPTModel"),
("electra", "FlaxElectraModel"),
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
@@ -211,6 +212,17 @@
]
)

FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
[
("dpt", "FlaxDPTForDepthEstimation"),
]
)

FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
("dpt", "FlaxDPTForSemanticSegmentation"),
]
)

FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
@@ -241,6 +253,12 @@
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)


class FlaxAutoModel(_BaseAutoModelClass):
@@ -344,3 +362,21 @@ class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


class FlaxAutoModelForDepthEstimation(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING


FlaxAutoModelForDepthEstimation = auto_class_update(
FlaxAutoModelForDepthEstimation, head_doc="depth estimation modeling"
)


class FlaxAutoModelForSemanticSegmentation(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


FlaxAutoModelForSemanticSegmentation = auto_class_update(
FlaxAutoModelForSemanticSegmentation, head_doc="semantic segmentation modeling"
)
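
With both mappings registered, any checkpoint whose config resolves to `dpt` dispatches to the new Flax heads. A hedged sketch, assuming the auto classes are also exported from the top-level `__init__` (not shown in this diff) and that the PyTorch checkpoints convert via `from_pt=True`:

```python
from transformers import FlaxAutoModelForDepthEstimation, FlaxAutoModelForSemanticSegmentation

# Both calls route through the ("dpt", ...) entries in the mappings above.
depth_model = FlaxAutoModelForDepthEstimation.from_pretrained("Intel/dpt-large", from_pt=True)
seg_model = FlaxAutoModelForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade", from_pt=True)
```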
33 changes: 32 additions & 1 deletion src/transformers/models/dpt/__init__.py
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available, is_vision_available
from ...file_utils import (
_LazyModule,
is_flax_available,
is_tokenizers_available,
is_torch_available,
is_vision_available,
)
from ...utils import OptionalDependencyNotAvailable


@@ -45,6 +51,19 @@
"DPTPreTrainedModel",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_dpt"] = [
"FlaxDPTForSemanticSegmentation",
"FlaxDPTForDepthEstimation",
"FlaxDPTModel",
"FlaxDPTPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_dpt import DPT_PRETRAINED_CONFIG_ARCHIVE_MAP, DPTConfig
@@ -71,6 +90,18 @@
DPTPreTrainedModel,
)

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_dpt import (
FlaxDPTForDepthEstimation,
FlaxDPTForSemanticSegmentation,
FlaxDPTModel,
FlaxDPTPreTrainedModel,
)

else:
import sys
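The guarded `_import_structure` block mirrors the repo-wide pattern: without `flax` installed the symbols are simply never registered, so torch-only installs are unaffected. Call sites can check the backend the same way this file does (sketch using the existing `is_flax_available` helper):

```python
from transformers.file_utils import is_flax_available

if is_flax_available():
    from transformers.models.dpt import FlaxDPTModel  # registered by the block above
else:
    print("flax is not installed; the Flax DPT classes are unavailable")
```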
2 changes: 2 additions & 0 deletions src/transformers/models/dpt/configuration_dpt.py
@@ -137,6 +137,7 @@ def __init__(
auxiliary_loss_weight=0.4,
semantic_loss_ignore_index=255,
semantic_classifier_dropout=0.1,
align_corners=True,
**kwargs
):
super().__init__(**kwargs)
@@ -168,3 +169,4 @@ def __init__(
self.auxiliary_loss_weight = auxiliary_loss_weight
self.semantic_loss_ignore_index = semantic_loss_ignore_index
self.semantic_classifier_dropout = semantic_classifier_dropout
self.align_corners = align_corners
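
The new `align_corners` flag ends up on `DPTConfig`; presumably it is surfaced because `jax.image.resize` has no `align_corners=True` mode, so the Flax port must make the interpolation convention explicit. That motivation is an inference, not stated in the diff. Setting it is straightforward:

```python
from transformers import DPTConfig

# align_corners defaults to True, per the diff above; flipping it off is shown
# only as an illustration of the new knob (assumption, not a recommendation).
config = DPTConfig(align_corners=False)
assert config.align_corners is False
```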