# coding=utf-8
# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
discrepancy, the original file should be regarded as the 'reference' version.
"""

from __future__ import annotations

import collections
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import ACT2FN
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
from ...tf_utils import flatten, functional_layernorm
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig


# Module-level logger obtained through the library's `logging.get_logger` helper.
logger = logging.get_logger(__name__)

# Config class name and example checkpoint id; presumably consumed by the docstring helpers
# (`add_start_docstrings*`, `replace_return_docstrings`) further down the file — confirm.
_CONFIG_FOR_DOC = "SamConfig"
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"


@dataclass
class TFSamVisionEncoderOutput(ModelOutput):
    """
    Base class for SAM vision model's outputs that also contains image embeddings obtained by applying the projection
    layer to the pooler_output.

    Args:
        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Field order is part of the public interface (tuple-style indexing of ModelOutput); do not reorder.
    image_embeds: tf.Tensor | None = None
    last_hidden_state: Optional[tf.Tensor] = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None


@dataclass
class TFSamImageSegmentationOutput(ModelOutput):
    """
    Base class for Segment-Anything model's output.

    Args:
        iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
            The IoU scores of the predicted masks.
        pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
            The predicted low-resolution masks. They need to be post-processed by the processor.
        vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
        vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Field order is part of the public interface (tuple-style indexing of ModelOutput); do not reorder.
    iou_scores: Optional[tf.Tensor] = None
    pred_masks: Optional[tf.Tensor] = None
    vision_hidden_states: Tuple[tf.Tensor, ...] | None = None
    vision_attentions: Tuple[tf.Tensor, ...] | None = None
    mask_decoder_attentions: Tuple[tf.Tensor, ...] | None = None


class TFSamPatchEmbeddings(keras.layers.Layer):
    """
    Project channels-first `pixel_values` of shape `(batch_size, num_channels, height, width)` into patch embeddings.

    Non-overlapping `patch_size` tiles are mapped to `hidden_size` features by a strided `Conv2D`. The pixels are
    transposed to channels-last before the convolution, so the result is the raw convolution output laid out as
    `(batch_size, height // patch_size, width // patch_size, hidden_size)`.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        def as_pair(value):
            # Scalars become a square (value, value) pair; iterables are kept as given.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = as_pair(config.image_size)
        patch_size = as_pair(config.patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        # Patches along the height times patches along the width.
        self.num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])

        self.projection = keras.layers.Conv2D(
            config.hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
        )

    def call(self, pixel_values):
        _, channels, height, width = shape_list(pixel_values)

        if channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )

        # Keras Conv2D defaults to channels-last, so move the channel axis to the end first.
        channels_last = tf.transpose(pixel_values, perm=[0, 2, 3, 1])
        return self.projection(channels_last)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        projection = getattr(self, "projection", None)
        if projection is None:
            return
        with tf.name_scope(projection.name):
            projection.build([None, None, None, self.num_channels])


class TFSamMLPBlock(keras.layers.Layer):
    """Feed-forward block: `lin1` (hidden_size -> mlp_dim), the configured activation, then `lin2` (mlp_dim -> hidden_size)."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.lin1 = keras.layers.Dense(config.mlp_dim, name="lin1")
        self.lin2 = keras.layers.Dense(config.hidden_size, name="lin2")
        # Activation looked up by name from `config.hidden_act`.
        self.act = ACT2FN[config.hidden_act]
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        return self.lin2(self.act(self.lin1(hidden_states)))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # lin1 consumes hidden_size features; lin2 consumes the mlp_dim features produced by lin1.
        for attr_name, in_features in (("lin1", self.config.hidden_size), ("lin2", self.config.mlp_dim)):
            dense = getattr(self, attr_name, None)
            if dense is not None:
                with tf.name_scope(dense.name):
                    dense.build([None, None, in_features])


class TFSamLayerNorm(keras.layers.Layer):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.

    The data format gives the ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape
    (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels,
    height, width). Normalization is delegated to `functional_layernorm` over the channel axis (`-1` for
    channels_last, `1` for channels_first) using the learned `weight` and `bias`.

    Args:
        normalized_shape: Shape of the learned `weight` and `bias` (the channel count at the call sites in this file).
        eps: Value forwarded to `functional_layernorm` as `epsilon`.
        data_format: `"channels_last"` or `"channels_first"`.

    Raises:
        NotImplementedError: If `data_format` is not one of the two supported values.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = normalized_shape
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")

    def build(self, input_shape=None):
        # Parent layers in this file call `build` explicitly. Without this guard a repeated call would create a second
        # pair of variables; it mirrors the `if self.built: return` pattern used by every other layer here.
        if self.built:
            return
        self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight")
        self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias")
        super().build(input_shape)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        # `__init__` only admits the two formats, so anything that is not channels_last is channels_first.
        axis = -1 if self.data_format == "channels_last" else 1
        return functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=axis)


class TFSamAttention(keras.layers.Layer):
    """
    SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values.

    Queries, keys and values are projected from `hidden_size` to `internal_dim = hidden_size // downsample_rate`, split
    into `num_attention_heads` heads, combined with scaled dot-product attention, and projected back to `hidden_size`
    by `out_proj`.

    Args:
        config: Configuration providing `hidden_size`, `num_attention_heads` and (when `downsample_rate` is None)
            `attention_downsample_rate`.
        downsample_rate (`int`, *optional*):
            Divisor applied to `hidden_size` to obtain the internal attention width. Defaults to
            `config.attention_downsample_rate`.

    Raises:
        ValueError: If `num_attention_heads` does not divide `internal_dim`.
    """

    def __init__(self, config, downsample_rate=None, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = config.hidden_size

        downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate

        self.internal_dim = config.hidden_size // downsample_rate
        self.num_attention_heads = config.num_attention_heads
        if self.internal_dim % config.num_attention_heads != 0:
            # The requirement is on the *downsampled* width, so report the values that were actually checked.
            raise ValueError(
                f"num_attention_heads ({config.num_attention_heads}) must divide the internal attention dimension "
                f"({self.internal_dim} = hidden_size {config.hidden_size} // downsample_rate {downsample_rate})."
            )

        self.q_proj = keras.layers.Dense(self.internal_dim, name="q_proj")
        self.k_proj = keras.layers.Dense(self.internal_dim, name="k_proj")
        self.v_proj = keras.layers.Dense(self.internal_dim, name="v_proj")
        self.out_proj = keras.layers.Dense(self.hidden_size, name="out_proj")

    def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor:
        """Reshape `(batch, point_batch, tokens, channels)` to `(batch * point_batch, heads, tokens, channels // heads)`."""
        batch, point_batch_size, n_tokens, channel = shape_list(hidden_states)
        c_per_head = channel // num_attention_heads
        hidden_states = tf.reshape(
            hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
        )
        return tf.transpose(hidden_states, perm=[0, 2, 1, 3])

    def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor:
        """Inverse of `_separate_heads`: back to `(batch, point_batch, tokens, heads * channels_per_head)`."""
        batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
        hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
        return tf.reshape(
            hidden_states,
            # max(1, point_batch_size) keeps the integer division defined when point_batch_size is 0.
            (batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
        )

    def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
        """
        Scaled dot-product attention of `query` over `key`/`value`.

        All three inputs must be rank 4, because `_separate_heads` unpacks four dimensions. Axis 1 of the projected
        query is treated as the point-batch axis. Returns a tensor in the layout of the query, with `hidden_size`
        features.
        """
        # Input projections
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        point_batch_size = shape_list(query)[1]
        # Separate into heads
        query = self._separate_heads(query, self.num_attention_heads)
        key = self._separate_heads(key, self.num_attention_heads)
        value = self._separate_heads(value, self.num_attention_heads)

        # SamAttention
        _, _, _, c_per_head = shape_list(query)
        attn = tf.matmul(
            query, tf.transpose(key, perm=[0, 1, 3, 2])
        )  # batch_size * point_batch_size  x N_heads x N_tokens x N_tokens
        attn = attn / tf.math.sqrt(float(c_per_head))
        attn = tf.nn.softmax(attn, axis=-1)

        # Get output
        out = tf.matmul(attn, value)
        out = self._recombine_heads(out, point_batch_size)
        out = self.out_proj(out)

        return out

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build([None, None, self.hidden_size])
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build([None, None, self.hidden_size])
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build([None, None, self.hidden_size])
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                # out_proj maps the internal (downsampled) width back to hidden_size.
                self.out_proj.build([None, None, self.internal_dim])


class TFSamTwoWayAttentionBlock(keras.layers.Layer):
    """Two-way transformer block updating both the sparse tokens (`queries`) and the image tokens (`keys`)."""

    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
        """
        A transformer block with four layers:
            (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
            sparse inputs (4) cross attention of dense inputs -> sparse inputs

        Arguments:
            config (`SamMaskDecoderConfig`):
                The configuration file used to instantiate the block
            attention_downsample_rate (*optional*, int, defaults to 2):
                The downsample ratio of the block used to reduce the inner dim of the attention.
            skip_first_layer_pe (*optional*, bool, defaults to `False`):
                Whether or not to skip the addition of the query_point_embedding on the first layer.
        """
        super().__init__(**kwargs)

        self.hidden_size = config.hidden_size
        self.layer_norm_eps = config.layer_norm_eps

        # Self-attention runs at full width (downsample_rate=1); the two cross-attentions use the block's rate.
        self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn")
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1")

        self.cross_attn_token_to_image = TFSamAttention(
            config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image"
        )
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2")

        self.mlp = TFSamMLPBlock(config, name="mlp")
        self.layer_norm3 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3")

        self.layer_norm4 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4")
        self.cross_attn_image_to_token = TFSamAttention(
            config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token"
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def call(
        self,
        queries: tf.Tensor,
        keys: tf.Tensor,
        query_point_embedding: tf.Tensor,
        key_point_embedding: tf.Tensor,
        output_attentions: bool = False,
    ):
        """
        Run the four sub-layers and return `(queries, keys, attn_out_or_None)`.

        Positional embeddings are added to the attention queries/keys only, never to the values. When
        `output_attentions` is set, the third element is the projected *output* of the image-to-token cross-attention
        (what `TFSamAttention.call` returns), not the softmax attention weights.
        """
        # Self attention block
        if self.skip_first_layer_pe:
            # No positional embedding and no residual: the attention output replaces `queries` outright.
            queries = self.self_attn(query=queries, key=queries, value=queries)
        else:
            query = queries + query_point_embedding
            attn_out = self.self_attn(query=query, key=query, value=queries)
            queries = queries + attn_out
        queries = self.layer_norm1(queries)

        # Cross attention block, tokens attending to image embedding
        query = queries + query_point_embedding
        key = keys + key_point_embedding

        attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
        queries = queries + attn_out

        queries = self.layer_norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.layer_norm3(queries)

        # Cross attention block, image embedding attending to tokens
        query = queries + query_point_embedding
        key = keys + key_point_embedding

        # Roles are swapped here: the image tokens act as queries and the sparse tokens as keys/values.
        attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
        keys = keys + attn_out

        keys = self.layer_norm4(keys)

        outputs = (queries, keys)

        if output_attentions:
            outputs = outputs + (attn_out,)
        else:
            outputs = outputs + (None,)

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # LayerNorms are built for rank-4 inputs with `hidden_size` features on the last axis.
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)
        if getattr(self, "layer_norm1", None) is not None:
            with tf.name_scope(self.layer_norm1.name):
                self.layer_norm1.build([None, None, None, self.hidden_size])
        if getattr(self, "cross_attn_token_to_image", None) is not None:
            with tf.name_scope(self.cross_attn_token_to_image.name):
                self.cross_attn_token_to_image.build(None)
        if getattr(self, "layer_norm2", None) is not None:
            with tf.name_scope(self.layer_norm2.name):
                self.layer_norm2.build([None, None, None, self.hidden_size])
        if getattr(self, "mlp", None) is not None:
            with tf.name_scope(self.mlp.name):
                self.mlp.build(None)
        if getattr(self, "layer_norm3", None) is not None:
            with tf.name_scope(self.layer_norm3.name):
                self.layer_norm3.build([None, None, None, self.hidden_size])
        if getattr(self, "layer_norm4", None) is not None:
            with tf.name_scope(self.layer_norm4.name):
                self.layer_norm4.build([None, None, None, self.hidden_size])
        if getattr(self, "cross_attn_image_to_token", None) is not None:
            with tf.name_scope(self.cross_attn_image_to_token.name):
                self.cross_attn_image_to_token.build(None)


class TFSamTwoWayTransformer(keras.layers.Layer):
    """
    Stack of `config.num_hidden_layers` two-way attention blocks followed by a final token-to-image attention.

    The first block is created with `skip_first_layer_pe=True`, so its self-attention does not add the point
    embeddings to the queries.
    """

    def __init__(self, config: SamMaskDecoderConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        self.num_hidden_layers = config.num_hidden_layers
        self.layers = []

        for i in range(self.num_hidden_layers):
            self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}"))

        self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image")
        self.layer_norm_final_attn = keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="layer_norm_final_attn"
        )

    def call(
        self,
        point_embeddings: tf.Tensor,
        image_embeddings: tf.Tensor,
        image_positional_embeddings: tf.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[tf.Tensor, ...]]:
        """
        Args:
            point_embeddings: Sparse tokens. They are the initial queries and are also added as the query positional
                embedding in every block and in the final attention.
            image_embeddings: Image features, presumably `(batch, channels, height, width)`. They are flattened with
                `flatten(..., 2)`, transposed with perm `(0, 2, 1)`, and given a singleton axis 1 before use as keys.
            image_positional_embeddings: Positional encodings with the same layout as `image_embeddings`.
            output_attentions: Whether to collect each block's attention output. Defaults to
                `config.output_attentions`.
            output_hidden_states: Accepted for API compatibility; not used by this layer.
            return_dict: Accepted for API compatibility; this layer always returns a plain tuple.

        Returns:
            `(queries, keys, all_attentions)`. `all_attentions` is an empty tuple unless `output_attentions` is truthy.

        Raises:
            ValueError: If `image_embeddings` is None.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")

        all_attentions = ()

        # Spatial positions become tokens; `[:, None]` inserts the point-batch axis expected by the attention layers.
        image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None]
        image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None]

        # Prepare queries
        queries = point_embeddings
        keys = image_embeddings

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys, attention_outputs = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                output_attentions=output_attentions,
            )

            if output_attentions:
                all_attentions = all_attentions + (attention_outputs,)

        # Apply the final attention layer from the points to the image
        query = queries + point_embeddings
        key = keys + image_positional_embeddings

        attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)

        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)
        return queries, keys, all_attentions

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "final_attn_token_to_image", None) is not None:
            with tf.name_scope(self.final_attn_token_to_image.name):
                self.final_attn_token_to_image.build(None)
        if getattr(self, "layer_norm_final_attn", None) is not None:
            with tf.name_scope(self.layer_norm_final_attn.name):
                self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size])
        for layer in self.layers:
            with tf.name_scope(layer.name):
                layer.build(None)


class TFSamFeedForward(keras.layers.Layer):
    """
    Multi-layer perceptron with ReLU between layers and an optional sigmoid on the output.

    Layout: `proj_in` (input_dim -> hidden_dim), `num_layers - 2` hidden Dense layers (hidden_dim -> hidden_dim), then
    `proj_out` (hidden_dim -> output_dim). ReLU follows every layer except `proj_out`. When `sigmoid_output` is True
    the final output is passed through `tf.sigmoid`.
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs
    ):
        super().__init__(**kwargs)
        self.num_layers = num_layers
        self.activation = keras.layers.ReLU()
        self.proj_in = keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in")
        self.proj_out = keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out")
        hidden_layers = []
        for idx in range(num_layers - 2):
            hidden_layers.append(keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{idx}"))
        self.layers = hidden_layers
        self.sigmoid_output = sigmoid_output
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim

    def call(self, hidden_states):
        outputs = self.activation(self.proj_in(hidden_states))
        for dense in self.layers:
            outputs = self.activation(dense(outputs))
        outputs = self.proj_out(outputs)
        return tf.sigmoid(outputs) if self.sigmoid_output else outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Build proj_in / proj_out first (same order as before), each under its own name scope.
        for attr_name, in_features in (("proj_in", self.input_dim), ("proj_out", self.hidden_dim)):
            sub_layer = getattr(self, attr_name, None)
            if sub_layer is not None:
                with tf.name_scope(sub_layer.name):
                    sub_layer.build([None, None, in_features])
        # Hidden layers all consume hidden_dim features.
        for dense in getattr(self, "layers", None) or []:
            with tf.name_scope(dense.name):
                dense.build([None, None, self.hidden_dim])


class TFSamMaskDecoder(keras.layers.Layer):
    """
    SAM mask decoder: predicts masks and IoU predictions from image embeddings and prompt embeddings.

    A learned IoU token and `num_multimask_outputs + 1` learned mask tokens are prepended to the sparse prompt
    embeddings. They are run through a two-way transformer together with the image embeddings, to which the dense
    prompt embeddings have been added. The transformer's image tokens are upsampled 4x by two stride-2 transposed
    convolutions. Each mask token is mapped by its own hypernetwork MLP to a `hidden_size // 8` vector, and its product
    with the upsampled features yields one mask per token. No sigmoid is applied here. The IoU token is fed to
    `iou_prediction_head` to score each mask.
    """

    def __init__(self, config: SamMaskDecoderConfig, **kwargs):
        super().__init__(**kwargs)

        self.hidden_size = config.hidden_size

        self.num_multimask_outputs = config.num_multimask_outputs
        # One token per multimask output plus one for the single-mask output (selected via `mask_slice` in `call`).
        self.num_mask_tokens = config.num_multimask_outputs + 1

        self.transformer = TFSamTwoWayTransformer(config, name="transformer")

        # kernel_size=2, strides=2 doubles height and width per layer (4x total), on channels-first tensors, while
        # reducing the channels to hidden_size // 4 and then hidden_size // 8.
        self.upscale_conv1 = keras.layers.Conv2DTranspose(
            self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first"
        )
        self.upscale_conv2 = keras.layers.Conv2DTranspose(
            self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first"
        )
        self.upscale_layer_norm = TFSamLayerNorm(
            self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm"
        )
        self.activation = tf.nn.gelu

        # One 3-layer hypernetwork MLP per mask token; its output width (hidden_size // 8) matches the channel count
        # produced by upscale_conv2 so the two can be multiplied in `call`.
        mlps_list = []
        for i in range(self.num_mask_tokens):
            mlps_list += [
                TFSamFeedForward(
                    self.hidden_size,
                    self.hidden_size,
                    self.hidden_size // 8,
                    3,
                    name=f"output_hypernetworks_mlps_._{i}",
                )
            ]
        self.output_hypernetworks_mlps = mlps_list

        # Maps the IoU token to one score per mask token.
        self.iou_prediction_head = TFSamFeedForward(
            self.hidden_size,
            config.iou_head_hidden_dim,
            self.num_mask_tokens,
            config.iou_head_depth,
            name="iou_prediction_head",
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Learned output tokens: 1 IoU token and `num_mask_tokens` mask tokens, each `hidden_size` wide.
        self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True)
        self.mask_tokens = self.add_weight(
            shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True
        )

        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
        if getattr(self, "upscale_conv1", None) is not None:
            with tf.name_scope(self.upscale_conv1.name):
                self.upscale_conv1.build([None, self.hidden_size, None, None])
        if getattr(self, "upscale_conv2", None) is not None:
            with tf.name_scope(self.upscale_conv2.name):
                self.upscale_conv2.build([None, self.hidden_size // 4, None, None])
        if getattr(self, "upscale_layer_norm", None) is not None:
            with tf.name_scope(self.upscale_layer_norm.name):
                self.upscale_layer_norm.build(None)
        if getattr(self, "iou_prediction_head", None) is not None:
            with tf.name_scope(self.iou_prediction_head.name):
                self.iou_prediction_head.build(None)
        for mlp in self.output_hypernetworks_mlps:
            with tf.name_scope(mlp.name):
                mlp.build(None)

    def call(
        self,
        image_embeddings: tf.Tensor,
        image_positional_embeddings: tf.Tensor,
        sparse_prompt_embeddings: tf.Tensor,
        dense_prompt_embeddings: tf.Tensor,
        multimask_output: bool,
        output_attentions: Optional[bool] = None,
    ) -> Tuple[tf.Tensor, tf.Tensor, Optional[Tuple[tf.Tensor, ...]]]:
        """
        Args:
            image_embeddings: Image features unpacked as `(batch_size, num_channels, height, width)`.
            image_positional_embeddings: Positional encodings for the image features; repeated along axis 0 exactly
                like `image_embeddings`.
            sparse_prompt_embeddings: Prompt tokens, concatenated after the output tokens along axis 2. Axis 1 is the
                point batch and may be empty, in which case only the learned output tokens are used.
            dense_prompt_embeddings: Added elementwise to `image_embeddings`.
            multimask_output: If truthy, return every mask except the first; otherwise return only the first mask.
            output_attentions: Forwarded to the two-way transformer.

        Returns:
            `(masks, iou_pred, attentions)`. `masks` is reshaped to `(batch_size, point_batch_size, num_masks, H, W)`,
            where `H`/`W` are the upsampled height/width. `iou_pred` is sliced on its last axis with the same
            `mask_slice`. `attentions` is None unless `output_attentions` is truthy.
        """
        batch_size, num_channels, height, width = shape_list(image_embeddings)
        # At least 1, so the tile/repeat/reshape below stay valid when there are no sparse prompts.
        point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1])

        output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0)  # (1 + num_mask_tokens, hidden_size)
        output_tokens = tf.tile(
            output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1]
        )  # (batch_size, point_batch_size, 1 + num_mask_tokens, hidden_size)

        # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
        #       happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
        #       it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
        if shape_list(sparse_prompt_embeddings)[1] != 0:
            tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
        else:
            tokens = output_tokens
        point_embeddings = tf.cast(tokens, self.iou_token.dtype)

        # Every point-batch entry gets its own copy of the (prompt-augmented) image features.
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0)
        image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0)

        point_embedding, image_embeddings, attentions = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            output_attentions=output_attentions,
        )
        # Token 0 is the IoU token and tokens 1..num_mask_tokens are the mask tokens (concat order above).
        iou_token_out = point_embedding[:, :, 0, :]
        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]

        # Turn the transformer's image tokens back into a channels-first grid for the transposed convolutions.
        image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2))
        image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width])

        upscaled_embedding = self.upscale_conv1(image_embeddings)
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))

        # Each mask token goes through its own hypernetwork MLP; results are stacked on the token axis.
        hyper_in_list = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        hyper_in = tf.stack(hyper_in_list, axis=2)

        # Flatten the upscaled spatial grid so each mask is a (tokens x channels) @ (channels x pixels) product.
        _, num_channels, height, width = shape_list(upscaled_embedding)
        upscaled_embedding = tf.reshape(
            upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width]
        )
        masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width])

        iou_pred = self.iou_prediction_head(iou_token_out)

        # Multimask: drop mask 0 and keep the rest; single-mask: keep only mask 0.
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, :, mask_slice, :, :]
        iou_pred = iou_pred[:, :, mask_slice]

        outputs = (masks, iou_pred)

        if output_attentions:
            outputs = outputs + (attentions,)
        else:
            outputs = outputs + (None,)

        return outputs


class TFSamPositionalEmbedding(keras.layers.Layer):
    """
    Sinusoidal positional encoding of 2D coordinates through a fixed projection matrix.

    Coordinates are mapped from `[0, 1]` to `[-1, 1]`. They are then multiplied by the non-trainable
    `(2, config.num_pos_feats)` `positional_embedding` matrix, scaled by `2 * pi`, and encoded as the concatenation of
    their sine and cosine. The output's last dimension is therefore `2 * config.num_pos_feats`.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Standard deviation of the normal initializer used for `positional_embedding`.
        self.scale = config.hidden_size // 2
        self.config = config

    def build(self, input_shape=None):
        # Re-entry guard, matching the other layers in this file. A repeated explicit `build` call would otherwise
        # create a second `positional_embedding` variable.
        if self.built:
            return
        # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized?
        # NOTE(review): presumably the random values are replaced when pretrained weights are loaded; confirm.
        self.positional_embedding = self.add_weight(
            name="positional_embedding",
            shape=(2, self.config.num_pos_feats),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
            trainable=False,
        )
        super().build(input_shape)

    def call(self, input_coords, input_shape=None):
        """
        Positionally encode points that are normalized to [0,1].

        Args:
            input_coords: Coordinates whose last axis holds `(x, y)`. When `input_shape` is given, the tensor must be
                rank 4 (it is indexed as `[:, :, :, i]`), and only the first two entries of the last axis are kept.
            input_shape (*optional*): Presumably `(height, width)`. x is divided by `input_shape[1]` and y by
                `input_shape[0]` to normalize pixel coordinates.

        Returns:
            The sine/cosine encoding, with last dimension `2 * config.num_pos_feats`.
        """
        coordinates = tf.identity(input_coords)

        if input_shape is not None:
            coordinates = tf.stack(
                [
                    tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1],
                    tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0],
                ],
                axis=-1,
            )

        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coordinates = 2 * coordinates - 1
        coordinates = tf.cast(coordinates, self.positional_embedding.dtype)
        coordinates = tf.matmul(coordinates, self.positional_embedding)
        coordinates = 2 * np.pi * coordinates
        # outputs d_1 x ... x d_n x channel shape
        return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1)


class TFSamMaskEmbedding(keras.layers.Layer):
    """
    Turns channels-first mask prompts into dense embeddings with `config.hidden_size` channels.

    Two stride-2, 2x2 convolutions (each followed by `TFSamLayerNorm` and the configured activation) shrink the
    spatial size by 4x. A final 1x1 convolution then projects to `config.hidden_size` channels. The result is
    returned channels-first.
    """

    def __init__(self, config: SamPromptEncoderConfig, **kwargs):
        super().__init__(**kwargs)
        narrow_channels = config.mask_input_channels // 4
        eps = config.layer_norm_eps
        self.mask_input_channels = narrow_channels
        self.activation = ACT2FN[config.hidden_act]
        self.conv1 = keras.layers.Conv2D(narrow_channels, kernel_size=2, strides=2, name="conv1")
        self.conv2 = keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2")
        self.conv3 = keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3")
        self.layer_norm1 = TFSamLayerNorm(narrow_channels, eps, name="layer_norm1")
        self.layer_norm2 = TFSamLayerNorm(narrow_channels * 4, eps, name="layer_norm2")
        self.config = config

    def call(self, masks):
        # Keras convolutions expect channels-last input.
        hidden_states = tf.transpose(masks, perm=(0, 2, 3, 1))
        for conv, norm in ((self.conv1, self.layer_norm1), (self.conv2, self.layer_norm2)):
            hidden_states = self.activation(norm(conv(hidden_states)))
        # Project to the hidden size and go back to channels-first.
        return tf.transpose(self.conv3(hidden_states), perm=(0, 3, 1, 2))

    def build(self, input_shape=None):
        # Built explicitly with fixed channel counts because the standard dummy inputs never reach this layer;
        # `input_shape` is ignored.
        if self.built:
            return
        self.built = True
        narrow = self.mask_input_channels
        wide = narrow * 4
        for sublayer, in_channels in (
            (self.conv1, 1),
            (self.conv2, narrow),
            (self.conv3, wide),
            (self.layer_norm1, narrow),
            (self.layer_norm2, wide),
        ):
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, None, in_channels])


class TFSamPromptEncoder(keras.layers.Layer):
    """
    Encodes SAM prompts (points, boxes and masks) into sparse and dense embeddings.

    Point and box coordinates are encoded with the positional-embedding layer passed in as `shared_patch_embedding`,
    plus learned per-type embeddings. Masks are encoded by `TFSamMaskEmbedding`. When no mask is given, a learned
    `no_mask_embed` vector is broadcast over the image-embedding grid instead.
    """

    def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs):
        super().__init__(**kwargs)
        # Layer used to positionally encode coordinates; it is called as `shared_embedding(coords, input_shape)`.
        self.shared_embedding = shared_patch_embedding
        self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed")
        # Created in `build`.
        self.no_mask_embed = None

        self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
        self.input_image_size = config.image_size

        # Created in `build`.
        self.point_embed = []
        self.hidden_size = config.hidden_size
        # Created in `build`.
        self.not_a_point_embed = None
        self.config = config

    def build(self, input_shape=None):
        """Create the learned prompt embeddings and explicitly build the mask embedder."""
        # NOTE(review): these weights are added before the `self.built` check below, so calling `build` a second time
        # would add duplicate weights.
        self.no_mask_embed = self.add_weight(
            name="no_mask_embed.weight",
            shape=(1, self.hidden_size),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
            trainable=True,
        )
        # One (1, hidden_size) vector per index: `_embed_points` uses indices 0/1 (labels 0/1) and `_embed_boxes`
        # uses indices 2/3 (first/second box corner).
        self.point_embed = [
            self.add_weight(
                name=f"point_embed_._{i}.weight",
                shape=(1, self.hidden_size),
                initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
                trainable=True,
            )
            for i in range(self.config.num_point_embeddings)
        ]
        # Substituted for points labelled -1 (including the padding point added by `_embed_points`).
        self.not_a_point_embed = self.add_weight(
            name="not_a_point_embed.weight",
            shape=(1, self.hidden_size),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
            trainable=True,
        )
        with tf.name_scope("mask_embed"):
            # We must explicitly build the mask embed because it isn't touched by the standard dummy inputs
            self.mask_embed.build(
                (None, self.config.mask_input_channels, self.config.image_size, self.config.image_size)
            )

        if self.built:
            return
        self.built = True
        # `mask_embed` was already built above, so this returns immediately via its own `built` guard.
        if getattr(self, "mask_embed", None) is not None:
            with tf.name_scope(self.mask_embed.name):
                self.mask_embed.build(None)

    def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor:
        """
        Embeds point prompts.

        Args:
            points (`tf.Tensor`):
                Pixel coordinates, indexed as `(batch, point_batch, num_points, 2)`.
            labels (`tf.Tensor`):
                Per-point labels of shape `(batch, point_batch, num_points)`.
            pad (`bool`):
                If True, append one extra zero-valued point with label -1 to every point set (along axis 2).

        Returns:
            Per-point embeddings. Labels are applied in this order:
            - `-1`: the embedding is replaced by `not_a_point_embed`;
            - `-10`: the embedding is zeroed (padding);
            - `0`: `point_embed[0]` is added;
            - `1`: `point_embed[1]` is added.
        """
        points = points + 0.5  # Shift to center of pixel
        if pad:
            target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
            target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
            padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
            padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
            points = tf.concat([points, padding_point], axis=2)
            labels = tf.concat([labels, padding_label], axis=2)
        input_shape = (self.input_image_size, self.input_image_size)
        point_embedding = self.shared_embedding(points, input_shape)

        point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding)

        point_embedding = tf.where(
            labels[..., None] != -10,
            point_embedding,
            tf.zeros_like(point_embedding),
        )
        point_embedding = tf.where(
            (labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding
        )
        point_embedding = tf.where(
            (labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding
        )
        return point_embedding

    def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
        """
        Embeds box prompts.

        Each box's 4 trailing values are reshaped into two 2D corners, giving `(batch, nb_boxes, 2, 2)`. The corners
        are positionally encoded, and `point_embed[2]` is added to the first corner and `point_embed[3]` to the
        second.
        """
        boxes = boxes + 0.5  # Shift to center of pixel
        batch_size, nb_boxes = shape_list(boxes)[:2]
        coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
        input_shape = (self.input_image_size, self.input_image_size)
        corner_embedding = self.shared_embedding(coords, input_shape)
        # Corner index 0 gets point_embed[2]; every other corner index gets point_embed[3].
        corner_embedding += tf.where(
            tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
            self.point_embed[2][0],
            self.point_embed[3][0],
        )
        return corner_embedding

    def call(
        self,
        batch_size: Optional[int],
        input_points: tf.Tensor | None,
        input_labels: tf.Tensor | None,
        input_boxes: tf.Tensor | None,
        input_masks: tf.Tensor | None,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense embeddings.

        Args:
            batch_size (`int`, *optional*):
                Batch size used for the dense (no-mask) and empty sparse embeddings. It is overwritten from
                `input_points` / `input_boxes` when those are given, so it must be provided when both are None.
            input_points (`tf.Tensor`, *optional*):
                Point coordinates, shape `(batch_size, point_batch_size, num_points, 2)`.
            input_labels (`tf.Tensor`, *optional*):
                Labels for `input_points`; required whenever `input_points` is given.
            input_boxes (`tf.Tensor`, *optional*):
                Boxes to embed. When given, points are not padded with an extra "not a point" entry.
            input_masks (`tf.Tensor`, *optional*):
                Masks to embed with `mask_embed`.

        Returns:
            `(sparse_embeddings, dense_embeddings)`.
            - Sparse: point and box embeddings concatenated along axis 2, or an empty
              `(batch_size, 0, 1, hidden_size)` tensor when neither is given.
            - Dense: the output of `mask_embed`, or `no_mask_embed` tiled to
              `(batch_size, hidden_size, *image_embedding_size)`.

        Raises:
            ValueError: if `input_points` is given without `input_labels`.
        """
        sparse_embeddings = None
        if input_points is not None:
            batch_size, point_batch_size = shape_list(input_points)[:2]
            if input_labels is None:
                raise ValueError("If points are provided, labels must also be provided.")
            point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
            sparse_embeddings = tf.zeros(
                (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype
            )
            sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
        if input_boxes is not None:
            batch_size = shape_list(input_boxes)[0]
            box_embeddings = self._embed_boxes(input_boxes)
            if sparse_embeddings is None:
                sparse_embeddings = box_embeddings
            else:
                # NOTE(review): concatenating on axis 2 requires nb_boxes to equal point_batch_size.
                sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2)
        if input_masks is not None:
            dense_embeddings = self.mask_embed(input_masks)
        else:
            # Broadcast the learned "no mask" vector over every position of the image-embedding grid.
            dense_embeddings = self.no_mask_embed[0]
            dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1))
            dense_embeddings = tf.tile(
                dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])
            )
        if sparse_embeddings is None:
            sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype)

        return sparse_embeddings, dense_embeddings


class TFSamVisionAttention(keras.layers.Layer):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(self, config, window_size, **kwargs):
        super().__init__(**kwargs)
        # Spatial size of the attended token grid:
        # - the full patch grid for global attention (`window_size == 0`);
        # - otherwise a single attention window.
        # The relative-position tables are sized from it.
        input_size = (
            (config.image_size // config.patch_size, config.image_size // config.patch_size)
            if window_size == 0
            else (window_size, window_size)
        )
        self.input_size = input_size

        self.num_attention_heads = config.num_attention_heads
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv = keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv")
        self.proj = keras.layers.Dense(config.hidden_size, name="proj")

        self.use_rel_pos = config.use_rel_pos
        if self.use_rel_pos:
            if input_size is None:
                raise ValueError("Input size must be provided if using relative positional encoding.")
        self.config = config

    def build(self, input_shape=None):
        # NOTE(review): the relative-position tables are created whenever `input_size` is set (always, given
        # `__init__`), even when `use_rel_pos` is False. They are also created before the `self.built` guard.
        if self.input_size is not None:
            # initialize relative positional embeddings
            self.rel_pos_h = self.add_weight(
                shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h"
            )
            self.rel_pos_w = self.add_weight(
                shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w"
            )

        if self.built:
            return
        self.built = True
        if getattr(self, "qkv", None) is not None:
            with tf.name_scope(self.qkv.name):
                self.qkv.build([None, None, self.config.hidden_size])
        if getattr(self, "proj", None) is not None:
            with tf.name_scope(self.proj.name):
                self.proj.build([None, None, self.config.hidden_size])

    def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
        """
        Get relative positional embeddings according to the relative positions of query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`tf.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions, of shape (q_size, k_size, channel).
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos if needed.
        if rel_pos.shape[0] != max_rel_dist:
            # Linearly interpolate along the position axis only. `tf.image.resize` treats a 3-D tensor as
            # (height, width, channels). We therefore lay the table out as (L, 1, channel), resize the height
            # from L to `max_rel_dist` and keep the width at 1. The result is (max_rel_dist, channel).
            # TF2's half-pixel-center resize is intended to match the reference PyTorch
            # `F.interpolate(..., mode="linear")`.
            rel_pos_resized = tf.image.resize(
                tf.reshape(rel_pos, (rel_pos.shape[0], 1, -1)),
                size=(max_rel_dist, 1),
                method="bilinear",
            )
            rel_pos_resized = tf.cast(tf.reshape(rel_pos_resized, (max_rel_dist, -1)), rel_pos.dtype)
        else:
            rel_pos_resized = rel_pos

        # Scale the coords with short length if shapes for q and k are different.
        q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0)
        k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

        # (q_size, k_size) integer indices into the (max_rel_dist, channel) table -> (q_size, k_size, channel).
        return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))

    def get_decomposed_rel_pos(
        self,
        query: tf.Tensor,
        rel_pos_h: tf.Tensor,
        rel_pos_w: tf.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> tf.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            query (`tf.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`tf.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`tf.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            decomposed_rel_pos (`tf.Tensor`):
                decomposed relative position embeddings, of shape
                (batch_size, query_height, query_width, key_height, key_width).
        """
        query_height, query_width = q_size
        key_height, key_width = k_size
        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)

        batch_size, _, dim = shape_list(query)
        reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
        rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)

        # Broadcast-add the height term over key_width and the width term over key_height.
        rel_h = tf.expand_dims(rel_h, axis=-1)
        rel_w = tf.expand_dims(rel_w, axis=-2)
        decomposed_rel_pos = rel_h + rel_w

        return decomposed_rel_pos

    def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
        """
        Self-attention over a channels-last `(batch_size, height, width, hidden_size)` feature map.

        Returns:
            `(attn_output, attn_weights)`. `attn_weights` holds the softmax probabilities before dropout, or `None`
            when `output_attentions` is False.
        """
        batch_size, height, width, _ = shape_list(hidden_states)
        # qkv with shape (3, batch_size, nHead, height * width, channel)
        qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1))
        qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4))
        # q, k, v with shape (batch_size * nHead, height * width, channel)
        query, key, value = tf.unstack(
            tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0
        )
        attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)

        if self.use_rel_pos:
            decomposed_rel_pos = self.get_decomposed_rel_pos(
                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )
            decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights))
            attn_weights = attn_weights + decomposed_rel_pos

        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        if training:
            attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
        else:
            attn_probs = attn_weights

        # Merge heads back into the channel axis: (batch, heads, h, w, c) -> (batch, h, w, heads * c).
        attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1))
        attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4))
        attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size))

        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs


class TFSamVisionLayer(keras.layers.Layer):
    """
    Pre-norm transformer block of the SAM vision encoder.

    The block applies two residual sub-blocks:
    1. Layer norm followed by self-attention.
    2. Layer norm followed by the MLP.

    When `window_size > 0`, attention runs separately inside non-overlapping `window_size x window_size` windows.
    The map is first zero-padded to a multiple of the window size. Otherwise attention covers the whole map.
    """

    def __init__(self, config, window_size, **kwargs):
        super().__init__(**kwargs)
        eps = config.layer_norm_eps
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=eps, name="layer_norm1")
        self.attn = TFSamVisionAttention(config, window_size, name="attn")
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=eps, name="layer_norm2")
        self.mlp = TFSamMLPBlock(config, name="mlp")
        self.window_size = window_size
        self.config = config

    def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]:
        """
        Cut a `(batch, height, width, channel)` map into `(batch * num_windows, window_size, window_size, channel)`.

        Height and width are zero-padded up to multiples of `window_size` first. The padded `(height, width)` is
        returned alongside the windows.
        """
        batch_size, height, width, channel = shape_list(hidden_states)

        extra_rows = (window_size - height % window_size) % window_size
        extra_cols = (window_size - width % window_size) % window_size
        if extra_rows > 0 or extra_cols > 0:
            hidden_states = tf.pad(hidden_states, [[0, 0], [0, extra_rows], [0, extra_cols], [0, 0]])
        padded_height = height + extra_rows
        padded_width = width + extra_cols

        grid = tf.reshape(
            hidden_states,
            [batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, channel],
        )
        # Bring the two window-index axes together so each window is contiguous.
        grid = tf.transpose(grid, perm=[0, 1, 3, 2, 4, 5])
        windows = tf.reshape(grid, [-1, window_size, window_size, channel])
        return windows, (padded_height, padded_width)

    def window_unpartition(
        self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
    ) -> tf.Tensor:
        """Inverse of `window_partition`: stitch windows back together and crop away the padding."""
        padded_height, padded_width = padding_shape
        height, width = original_shape
        windows_per_image = padded_height * padded_width // window_size // window_size
        batch_size = shape_list(windows)[0] // windows_per_image

        grid = tf.reshape(
            windows,
            [batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1],
        )
        grid = tf.transpose(grid, perm=[0, 1, 3, 2, 4, 5])
        stitched = tf.reshape(grid, [batch_size, padded_height, padded_width, -1])

        needs_crop = padded_height > height or padded_width > width
        return stitched[:, :height, :width, :] if needs_crop else stitched

    def call(
        self,
        hidden_states: tf.Tensor,
        output_attentions: Optional[bool] = False,
        training: Optional[bool] = False,
    ) -> Tuple[tf.Tensor]:
        shortcut = hidden_states
        normed = self.layer_norm1(hidden_states)

        windowed = self.window_size > 0
        if windowed:
            original_hw = (normed.shape[1], normed.shape[2])
            normed, padded_hw = self.window_partition(normed, self.window_size)

        attended, attn_weights = self.attn(
            hidden_states=normed,
            output_attentions=output_attentions,
            training=training,
        )
        if windowed:
            attended = self.window_unpartition(attended, self.window_size, padded_hw, original_hw)

        hidden_states = shortcut + attended
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        return (hidden_states, attn_weights) if output_attentions else (hidden_states,)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for attr_name, sub_shape in (
            ("layer_norm1", [None, None, None, self.config.hidden_size]),
            ("attn", None),
            ("layer_norm2", [None, None, None, self.config.hidden_size]),
            ("mlp", None),
        ):
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(sub_shape)


class TFSamVisionNeck(keras.layers.Layer):
    """
    Projects the encoder's channels-last feature map to `config.output_channels` and returns it channels-first.

    The map goes through two bias-free convolutions, each followed by `TFSamLayerNorm`: first a 1x1 convolution,
    then a 3x3 convolution with "same" padding.
    """

    def __init__(self, config: SamVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        out_channels = config.output_channels

        self.conv1 = keras.layers.Conv2D(out_channels, kernel_size=1, use_bias=False, name="conv1")
        self.layer_norm1 = TFSamLayerNorm(out_channels, name="layer_norm1")
        self.conv2 = keras.layers.Conv2D(out_channels, kernel_size=3, padding="same", use_bias=False, name="conv2")
        self.layer_norm2 = TFSamLayerNorm(out_channels, name="layer_norm2")

    def call(self, hidden_states):
        for conv, norm in ((self.conv1, self.layer_norm1), (self.conv2, self.layer_norm2)):
            hidden_states = norm(conv(hidden_states))
        # channels-last -> channels-first
        return tf.transpose(hidden_states, perm=[0, 3, 1, 2])

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        input_shapes = {
            "conv1": [None, None, None, self.config.hidden_size],
            "layer_norm1": None,
            "conv2": [None, None, None, self.config.output_channels],
            "layer_norm2": None,
        }
        for attr_name, sub_shape in input_shapes.items():
            sublayer = getattr(self, attr_name, None)
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(sub_shape)


class TFSamVisionEncoder(keras.layers.Layer):
    """
    SAM ViT image encoder.

    The encoder applies, in order:
    - patch embedding;
    - an optional absolute position embedding;
    - `config.num_hidden_layers` `TFSamVisionLayer`s, which use windowed attention except at
      `config.global_attn_indexes`, where attention is global;
    - a convolutional neck.
    """

    def __init__(self, config: SamVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.image_size = config.image_size

        self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed")

        # Created in `build` only when `config.use_abs_pos` is set.
        self.pos_embed = None

        self.layers = [
            TFSamVisionLayer(
                config,
                window_size=0 if idx in config.global_attn_indexes else config.window_size,
                name=f"layers_._{idx}",
            )
            for idx in range(config.num_hidden_layers)
        ]

        self.neck = TFSamVisionNeck(config, name="neck")

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if self.config.use_abs_pos:
            # Absolute positional embedding sized for the pretraining patch grid.
            grid_size = self.config.image_size // self.config.patch_size
            self.pos_embed = self.add_weight(
                shape=[1, grid_size, grid_size, self.config.hidden_size],
                initializer="zeros",
                trainable=True,
                name="pos_embed",
            )

        for attr_name in ("patch_embed", "neck"):
            sublayer = getattr(self, attr_name, None)
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
        for layer in self.layers:
            with tf.name_scope(layer.name):
                layer.build(None)

    def get_input_embeddings(self):
        return self.patch_embed

    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[Tuple, TFSamVisionEncoderOutput]:
        """
        Encode `pixel_values` into a channels-first feature map.

        When `output_hidden_states` is set, the hidden states are collected before the neck. They are the
        patch-embedding output followed by the output of every layer.

        Raises:
            ValueError: if `pixel_values` is None.
        """
        config = self.config
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if return_dict is None:
            return_dict = config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.patch_embed(pixel_values)
        if self.pos_embed is not None:
            hidden_states = hidden_states + self.pos_embed

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for layer_module in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        hidden_states = self.neck(hidden_states)

        if return_dict:
            return TFSamVisionEncoderOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
            )

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs += (all_hidden_states,)
        if output_attentions:
            outputs += (all_self_attentions,)
        return outputs


class TFSamPreTrainedModel(TFPreTrainedModel):
    """
    Common base for the TensorFlow SAM models.

    It binds `SamConfig` as the configuration class, uses `"sam"` as the base-model prefix and declares
    `pixel_values` as the main input.
    """

    main_input_name = "pixel_values"
    base_model_prefix = "sam"
    config_class = SamConfig


# Class-level documentation shared by the SAM models; prepended to their docstrings via `add_start_docstrings`.
SAM_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a TensorFlow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
    subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to
    general usage and behavior.

    Parameters:
        config ([`SamConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""


# Argument documentation for the full SAM model's forward pass (image, point/box/mask prompts, image embeddings).
SAM_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
            details.
        input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points, 2)`):
            Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
            better results. The points can be obtained by passing a list of list of list to the processor that will
            create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second
            dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per
            input point), the third dimension is the number of points per segmentation mask (it is possible to pass
            multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
            coordinates of the point. If a different number of points is passed either for each image, or for each
            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
            computation of the embedding will be skipped for these points using the labels.
        input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`):
            Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
            official implementation, there are 3 types of labels

            - `1`: the point is a point that contains the object of interest
            - `0`: the point is a point that does not contain the object of interest
            - `-1`: the point corresponds to the background

            We added the label:

            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder

            Padding labels are added automatically by the processor.
        input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`):
            Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
            much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
            that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size,
            the number of boxes per image and the coordinates of the top left and bottom right point of the box. In the
            order (`x1`, `y1`, `x2`, `y2`):

            - `x1`: the x coordinate of the top left point of the input box
            - `y1`: the y coordinate of the top left point of the input box
            - `x2`: the x coordinate of the bottom right point of the input box
            - `y2`: the y coordinate of the bottom right point of the input box

        input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
            SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
            generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
            manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).

        image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`):
            Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
            efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
            method, and then feed them to the `call` method instead of feeding the `pixel_values`.
        multimask_output (`bool`, *optional*):
            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
            bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
            "best" mask, by specifying `multimask_output=False`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# Argument documentation for the vision-only model; injected into `TFSamVisionModel.call` through
# `add_start_docstrings_to_model_forward`.
SAM_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
            details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    """The vision model from Sam without any head or projection on top.""",
    SAM_START_DOCSTRING,
)
class TFSamVisionModel(TFSamPreTrainedModel):
    # Standalone wrapper around `TFSamVisionEncoder`: construction, building and the forward pass are all
    # delegated to a single encoder sub-layer named "vision_encoder".
    config_class = SamVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: SamVisionConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder")

    def build(self, input_shape=None):
        # Idempotent: only the first call does any work.
        if self.built:
            return
        self.built = True
        encoder = getattr(self, "vision_encoder", None)
        if encoder is None:
            return
        # Build the encoder under a name scope equal to its own layer name.
        with tf.name_scope(encoder.name):
            encoder.build(None)

    def get_input_embeddings(self):
        # The encoder's patch-embedding layer serves as this model's input embedding.
        encoder = self.vision_encoder
        return encoder.patch_embed

    @unpack_inputs
    @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig)
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
        **kwargs,
    ) -> TFSamVisionEncoderOutput | Tuple[tf.Tensor]:
        r"""
        Returns:

        """
        # Extra **kwargs are accepted by the signature but deliberately not forwarded to the encoder.
        forward_options = {
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
            "training": training,
        }
        # pixel_values is passed positionally, exactly as the encoder has always received it.
        return self.vision_encoder(pixel_values, **forward_options)


@add_start_docstrings(
    "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
    "optional 2D location and bounding boxes.",
    SAM_START_DOCSTRING,
)
class TFSamModel(TFSamPreTrainedModel):
    # Full SAM pipeline: vision encoder -> prompt encoder -> mask decoder.
    _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # One positional-embedding layer is shared: it is handed to the prompt encoder and is also used by
        # `get_image_wide_positional_embeddings`, so both use the same weights.
        self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding")

        self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder")
        self.prompt_encoder = TFSamPromptEncoder(
            config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder"
        )
        self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder")
        self.config = config

    def get_input_embeddings(self):
        return self.vision_encoder.get_input_embeddings()

    def get_image_wide_positional_embeddings(self):
        r"""
        Returns positional embeddings for every cell of the `size x size` image-embedding grid, where `size` is
        `config.prompt_encoder_config.image_embedding_size`.

        Cell-center coordinates `(x, y)` are normalized to the open interval (0, 1) and then embedded by the shared
        positional-embedding layer. The result is transposed to channels-first and given a leading axis of
        length 1, i.e. `(1, channels, size, size)`.
        """
        size = self.config.prompt_encoder_config.image_embedding_size
        grid = tf.ones((size, size))
        # cumsum of ones yields 1..size along each axis; subtracting 0.5 moves to cell centers.
        y_embed = tf.math.cumsum(grid, axis=0) - 0.5
        x_embed = tf.math.cumsum(grid, axis=1) - 0.5
        y_embed = y_embed / size
        x_embed = x_embed / size

        positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1))
        # (size, size, channels) -> (1, channels, size, size)
        return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0)

    def get_image_embeddings(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Returns the image embeddings by passing the pixel values through the vision encoder.

        Args:
            pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
                Input pixel values
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not the vision encoder should return a [`~utils.ModelOutput`] instead of a plain tuple.
                Only the first element (the image embeddings) is returned either way.

        """
        vision_output = self.vision_encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeddings = vision_output[0]
        return image_embeddings

    def get_prompt_embeddings(
        self,
        input_points: tf.Tensor | None = None,
        input_labels: tf.Tensor | None = None,
        input_boxes: tf.Tensor | None = None,
        input_masks: tf.Tensor | None = None,
    ):
        r"""
        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.

        Args:
            input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
                Optional input points for the prompt encoder. The padding of the point is automatically done by the
                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
                point. The model will output `point_batch_size` times 3 masks in total.
            input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
                processor, or can be fed by the user.
            input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`):
                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
                processor. users can also pass manually the input boxes.
            input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
                Optional input masks for the prompt encoder.

        Returns:
            The prompt encoder output, i.e. the `(sparse_embeddings, dense_embeddings)` pair that `call` consumes.

        Note:
            The prompt encoder is given an explicit `batch_size`, exactly as `call` does. It is taken from the
            leading dimension of the first provided input among `input_points`, `input_boxes` and `input_masks`.
            It defaults to 1 when no prompt is given.
        """
        if input_points is not None:
            batch_size = shape_list(input_points)[0]
        elif input_boxes is not None:
            batch_size = shape_list(input_boxes)[0]
        elif input_masks is not None:
            batch_size = shape_list(input_masks)[0]
        else:
            batch_size = 1

        prompt_output = self.prompt_encoder(
            batch_size=batch_size,
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )
        return prompt_output

    @unpack_inputs
    @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        input_points: tf.Tensor | None = None,
        input_labels: tf.Tensor | None = None,
        input_boxes: tf.Tensor | None = None,
        input_masks: tf.Tensor | None = None,
        image_embeddings: tf.Tensor | None = None,
        multimask_output: bool = True,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
        **kwargs,
    ) -> TFSamImageSegmentationOutput | Tuple[tf.Tensor]:
        # Exactly one of `pixel_values` / `image_embeddings` drives the image side; prompts are optional.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and image_embeddings is None:
            raise ValueError("Either pixel_values or image_embeddings must be provided.")

        if pixel_values is not None and image_embeddings is not None:
            raise ValueError("Only one of pixel_values and image_embeddings can be provided.")

        # Messages are built as a single string so that `str(err)` reads as a sentence rather than a tuple.
        if input_points is not None and len(input_points.shape) != 4:
            raise ValueError(
                "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, "
                f"`nb_points_per_image`, `2`. got {input_points.shape}."
            )
        if input_boxes is not None and len(input_boxes.shape) != 3:
            raise ValueError(
                f"The input_boxes must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`. got {input_boxes.shape}."
            )
        if input_points is not None and input_boxes is not None:
            point_batch_size = shape_list(input_points)[1]
            box_batch_size = shape_list(input_boxes)[1]
            if point_batch_size != box_batch_size:
                raise ValueError(
                    "You should provide as many bounding boxes as input points per box. "
                    f"Got {point_batch_size} and {box_batch_size}."
                )
        if pixel_values is not None:
            # Ensures that later checks pass even with an all-None shape from the serving signature
            pixel_values = tf.ensure_shape(
                pixel_values,
                [
                    None,
                    self.config.vision_config.num_channels,
                    self.config.vision_config.image_size,
                    self.config.vision_config.image_size,
                ],
            )
        image_positional_embeddings = self.get_image_wide_positional_embeddings()
        # The positional embeddings have a leading axis of 1; repeat them once per image in the batch.
        batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0]
        image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0)

        vision_attentions = None
        vision_hidden_states = None

        if pixel_values is not None:
            vision_outputs = self.vision_encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                training=training,
            )
            image_embeddings = vision_outputs["last_hidden_state"]

            if output_hidden_states:
                vision_hidden_states = vision_outputs["hidden_states"]
            if output_attentions:
                vision_attentions = vision_outputs["attentions"]

        # Points without labels are all treated as label 1.
        if input_points is not None and input_labels is None:
            input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32)

        if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
            raise ValueError(
                "The batch size of the image embeddings and the input points must be the same. "
                f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively. "
                "If you want to pass multiple points for the same image, make sure that you passed "
                "input_points of shape (batch_size, point_batch_size, num_points_per_image, 2) and "
                "input_labels of shape (batch_size, point_batch_size, num_points_per_image)"
            )

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            batch_size=shape_list(image_embeddings)[0],
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )

        low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            output_attentions=output_attentions,
        )

        if not return_dict:
            output = (iou_predictions, low_res_masks)
            if output_hidden_states:
                output = output + (vision_hidden_states,)

            if output_attentions:
                output = output + (vision_attentions, mask_decoder_attentions)
            return output

        return TFSamImageSegmentationOutput(
            iou_scores=iou_predictions,
            pred_masks=low_res_masks,
            vision_hidden_states=vision_hidden_states,
            vision_attentions=vision_attentions,
            mask_decoder_attentions=mask_decoder_attentions,
        )

    def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput:
        # Optional fields are kept only when the config enables them. Hidden states and attentions are converted
        # with `tf.convert_to_tensor`.
        hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None
        attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None

        return TFSamImageSegmentationOutput(
            iou_scores=output.iou_scores,
            pred_masks=output.pred_masks,
            vision_hidden_states=hs,
            vision_attentions=attns,
            mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None,
        )

    def build(self, input_shape=None):
        # Idempotent. Each existing sub-layer is built under a name scope matching its layer name, in this order.
        if self.built:
            return
        self.built = True
        for attr_name in ("shared_image_embedding", "vision_encoder", "prompt_encoder", "mask_decoder"):
            layer = getattr(self, attr_name, None)
            if layer is not None:
                with tf.name_scope(layer.name):
                    layer.build(None)


# Public API of this module.
__all__ = [
    "TFSamVisionModel",
    "TFSamModel",
    "TFSamPreTrainedModel",
]
