# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers DAC model."""

import math
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from .configuration_dac import DacConfig


# General docstring
_CONFIG_FOR_DOC = "DacConfig"


@dataclass
class DacOutput(ModelOutput):
    """
    Args:
        loss (`torch.Tensor`):
            Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
        audio_values (`torch.Tensor` of shape `(batch_size, input_length)`):
            Reconstructed audio data.
        quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
            Quantized continuous representation of input.
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`):
            Codebook indices for each codebook (quantized discrete representation of input).
        projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
            Projected latents (continuous representation of input before quantization).
    """

    loss: Optional[torch.FloatTensor] = None
    audio_values: Optional[torch.FloatTensor] = None
    quantized_representation: Optional[torch.FloatTensor] = None
    audio_codes: Optional[torch.LongTensor] = None
    projected_latents: Optional[torch.FloatTensor] = None


@dataclass
class DacEncoderOutput(ModelOutput):
    """
    Args:
        loss (`torch.Tensor`):
            Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
        quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
            Quantized continuous representation of input.
        audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
            Codebook indices for each codebook (quantized discrete representation of input).
        projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`, *optional*):
            Projected latents (continuous representation of input before quantization).
    """

    loss: Optional[torch.FloatTensor] = None
    quantized_representation: Optional[torch.FloatTensor] = None
    audio_codes: Optional[torch.FloatTensor] = None
    projected_latents: Optional[torch.FloatTensor] = None


@dataclass
# Copied from transformers.models.encodec.modeling_encodec.EncodecDecoderOutput with Encodec->Dac, segment_length->input_length
class DacDecoderOutput(ModelOutput):
    """
    Args:
        audio_values (`torch.FloatTensor`  of shape `(batch_size, input_length)`, *optional*):
            Decoded audio values, obtained using the decoder part of Dac.
    """

    audio_values: Optional[torch.FloatTensor] = None


class Snake1d(nn.Module):
    """
    A 1-dimensional Snake activation function module.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))

    def forward(self, hidden_states):
        shape = hidden_states.shape
        hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
        hidden_states = hidden_states + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * hidden_states).pow(2)
        hidden_states = hidden_states.reshape(shape)
        return hidden_states


class DacVectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo (https://github.com/karpathy/deep-vector-quantization)

    Additionally uses following tricks from improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, config: DacConfig):
        super().__init__()

        self.in_proj = nn.Conv1d(config.hidden_size, config.codebook_dim, kernel_size=1)
        self.out_proj = nn.Conv1d(config.codebook_dim, config.hidden_size, kernel_size=1)
        self.codebook = nn.Embedding(config.codebook_size, config.codebook_dim)

    def forward(self, hidden_state):
        """
        Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.

        Args:
            hidden_state (`torch.FloatTensor` of shape `(batch_size, dimension, time_steps)`):
                Input tensor.

        Returns:
            quantized_representation (`torch.Tensor`of shape `(batch_size, dimension, time_steps)`):
                Quantized continuous representation of input.
            commitment_loss (`torch.FloatTensor`of shape `(1)`):
                Commitment loss to train encoder to predict vectors closer to codebook entries.
            codebook_loss (`torch.FloatTensor`of shape `(1)`):
                Codebook loss to update the codebook.
            audio_codes (`torch.LongTensor` of shape `(batch_size, time_steps)`):
                Codebook indices for each codebook, quantized discrete representation of input.
            projected_latents (torch.FloatTensor of shape `(batch_size, num_codebooks * dimension, time_steps)`):
                Projected latents (continuous representation of input before quantization).
        """

        projected_latents = self.in_proj(hidden_state)
        quantized_representation, audio_codes = self.decode_latents(projected_latents)

        commitment_loss = F.mse_loss(projected_latents, quantized_representation.detach(), reduction="mean")
        codebook_loss = F.mse_loss(quantized_representation, projected_latents.detach(), reduction="mean")
        # noop in forward pass, straight-through gradient estimator in backward pass
        quantized_representation = projected_latents + (quantized_representation - projected_latents).detach()
        quantized_representation = self.out_proj(quantized_representation)

        return quantized_representation, commitment_loss, codebook_loss, audio_codes, projected_latents

    def decode_latents(self, hidden_states):
        batch_size, hidden_dim, sequence_length = hidden_states.shape
        encodings = hidden_states.permute(0, 2, 1).reshape(batch_size * sequence_length, hidden_dim)
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        l2_norm = encodings.pow(2).sum(1, keepdim=True)
        dist = -(l2_norm - 2 * encodings @ codebook.t()) + codebook.pow(2).sum(1, keepdim=True).t()

        indices = dist.max(1)[1]
        indices = indices.reshape(hidden_states.size(0), -1)
        quantized_representation = self.codebook(indices).transpose(1, 2)
        return quantized_representation, indices


class DacResidualUnit(nn.Module):
    """
    A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
    """

    def __init__(self, dimension: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2

        self.snake1 = Snake1d(dimension)
        self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
        self.snake2 = Snake1d(dimension)
        self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)

    def forward(self, hidden_state):
        """
        Forward pass through the residual unit.

        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
                Input tensor .

        Returns:
            output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
                Input tensor after passing through the residual unit.
        """
        output_tensor = hidden_state
        output_tensor = self.conv1(self.snake1(output_tensor))
        output_tensor = self.conv2(self.snake2(output_tensor))

        padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
        if padding > 0:
            hidden_state = hidden_state[..., padding:-padding]
        output_tensor = hidden_state + output_tensor
        return output_tensor


class DacEncoderBlock(nn.Module):
    """Encoder block used in DAC encoder."""

    def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
        super().__init__()

        dimension = config.encoder_hidden_size * 2**stride_index
        self.res_unit1 = DacResidualUnit(dimension // 2, dilation=1)
        self.res_unit2 = DacResidualUnit(dimension // 2, dilation=3)
        self.res_unit3 = DacResidualUnit(dimension // 2, dilation=9)
        self.snake1 = Snake1d(dimension // 2)
        self.conv1 = nn.Conv1d(
            dimension // 2, dimension, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
        )

    def forward(self, hidden_state):
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        hidden_state = self.snake1(self.res_unit3(hidden_state))
        hidden_state = self.conv1(hidden_state)

        return hidden_state


class DacDecoderBlock(nn.Module):
    """Decoder block used in DAC decoder."""

    def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
        super().__init__()

        input_dim = config.decoder_hidden_size // 2**stride_index
        output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
        self.snake1 = Snake1d(input_dim)
        self.conv_t1 = nn.ConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
        )

        self.res_unit1 = DacResidualUnit(output_dim, dilation=1)
        self.res_unit2 = DacResidualUnit(output_dim, dilation=3)
        self.res_unit3 = DacResidualUnit(output_dim, dilation=9)

    def forward(self, hidden_state):
        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv_t1(hidden_state)
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        hidden_state = self.res_unit3(hidden_state)

        return hidden_state


class DacResidualVectorQuantize(nn.Module):
    """
    ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://arxiv.org/abs/2107.03312)
    """

    def __init__(self, config: DacConfig):
        super().__init__()

        n_codebooks = config.n_codebooks
        quantizer_dropout = config.quantizer_dropout

        self.n_codebooks = n_codebooks

        self.quantizers = nn.ModuleList([DacVectorQuantize(config) for i in range(config.n_codebooks)])
        self.quantizer_dropout = quantizer_dropout

    def forward(self, hidden_state, n_quantizers: Optional[int] = None):
        """
        Quantizes the input tensor using a fixed set of codebooks and returns corresponding codebook vectors.
        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Input tensor to be quantized.
            n_quantizers (`int`, *optional*):
                Number of quantizers to use. If specified and `self.quantizer_dropout` is True,
                this argument is ignored during training, and a random number of quantizers is used.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized continuous representation of input.
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
                Codebook indices for each codebook (quantized discrete representation of input).
            projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
                Projected latents (continuous representation of input before quantization).
            commitment_loss (`torch.Tensor` of shape `(1)`):
                Commitment loss to train the encoder to predict vectors closer to codebook entries.
            codebook_loss (`torch.Tensor` of shape `(1)`):
                Codebook loss to update the codebook.
        """

        quantized_representation = 0
        residual = hidden_state
        commitment_loss = 0
        codebook_loss = 0

        audio_codes = []
        projected_latents = []

        n_quantizers = n_quantizers if n_quantizers is not None else self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((hidden_state.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (hidden_state.shape[0],))
            n_dropout = int(hidden_state.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(hidden_state.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            quantized_representation_i, commitment_loss_i, codebook_loss_i, indices_i, projected_latents_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = torch.full((hidden_state.shape[0],), fill_value=i, device=hidden_state.device) < n_quantizers
            quantized_representation = quantized_representation + quantized_representation_i * mask[:, None, None]
            residual = residual - quantized_representation_i

            # Sum losses
            commitment_loss += commitment_loss_i * mask
            codebook_loss += codebook_loss_i * mask

            audio_codes.append(indices_i)
            projected_latents.append(projected_latents_i)

        audio_codes = torch.stack(audio_codes, dim=1)
        projected_latents = torch.cat(projected_latents, dim=1)

        return quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss

    def from_codes(self, audio_codes: torch.Tensor):
        """
        Reconstructs the continuous representation from quantized codes.

        Args:
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
                Quantized discrete representation of input.

        Returns:
            quantized_representation (`torch.Tensor`):
                Quantized continuous representation of input.
            projected_latents (`torch.Tensor`):
                List of projected latents (continuous representations of input before quantization)
                for each codebook.
            audio_codes (`torch.Tensor`):
                Codebook indices for each codebook.
        """
        quantized_representation = 0.0
        projected_latents = []
        n_codebooks = audio_codes.shape[1]
        for i in range(n_codebooks):
            projected_latents_i = self.quantizers[i].codebook(audio_codes[:, i, :]).transpose(1, 2)
            projected_latents.append(projected_latents_i)
            quantized_representation += self.quantizers[i].out_proj(projected_latents_i)
        return quantized_representation, torch.cat(projected_latents, dim=1), audio_codes

    def from_latents(self, latents: torch.Tensor):
        """Reconstructs the quantized representation from unquantized latents.

        Args:
            latents (`torch.Tensor` of shape `(batch_size, total_latent_dimension, time_steps)`):
                Continuous representation of input after projection.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized representation of the full-projected space.
            quantized_latents (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized representation of the latent space (continuous representation before quantization).
        """
        quantized_representation = 0
        quantized_latents = []
        codes = []
        codebook_dims_tensor = torch.tensor([0] + [q.codebook_dim for q in self.quantizers])
        dims = torch.cumsum(codebook_dims_tensor, dim=0)

        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            hidden_dim_j, hidden_dim_k = dims[i], dims[i + 1]
            quantized_latents_i, codes_i = self.quantizers[i].decode_latents(latents[:, hidden_dim_j:hidden_dim_k, :])
            quantized_latents.append(quantized_latents_i)
            codes.append(codes_i)

            quantized_representation_i = self.quantizers[i].out_proj(quantized_latents_i)
            quantized_representation = quantized_representation + quantized_representation_i

        return quantized_representation, torch.cat(quantized_latents, dim=1)


class DacDecoder(nn.Module):
    """DAC Decoder"""

    def __init__(self, config: DacConfig):
        super().__init__()

        input_channel = config.hidden_size
        channels = config.decoder_hidden_size
        strides = config.upsampling_ratios

        # Add first conv layer
        self.conv1 = nn.Conv1d(input_channel, channels, kernel_size=7, padding=3)

        # Add upsampling + MRF blocks
        block = []
        for stride_index, stride in enumerate(strides):
            block += [DacDecoderBlock(config, stride, stride_index)]

        self.block = nn.ModuleList(block)
        output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
        self.snake1 = Snake1d(output_dim)
        self.conv2 = nn.Conv1d(output_dim, 1, kernel_size=7, padding=3)
        self.tanh = nn.Tanh()

    def forward(self, hidden_state):
        hidden_state = self.conv1(hidden_state)

        for layer in self.block:
            hidden_state = layer(hidden_state)

        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv2(hidden_state)
        hidden_state = self.tanh(hidden_state)

        return hidden_state


class DacEncoder(nn.Module):
    """DAC Encoder"""

    def __init__(self, config: DacConfig):
        super().__init__()

        strides = config.downsampling_ratios
        # Create first convolution
        self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)

        self.block = []
        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride_index, stride in enumerate(strides):
            stride_index = stride_index + 1
            self.block += [DacEncoderBlock(config, stride=stride, stride_index=stride_index)]

        self.block = nn.ModuleList(self.block)
        d_model = config.encoder_hidden_size * 2**stride_index
        self.snake1 = Snake1d(d_model)
        self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)

    def forward(self, hidden_state):
        hidden_state = self.conv1(hidden_state)

        for module in self.block:
            hidden_state = module(hidden_state)

        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv2(hidden_state)

        return hidden_state


class DacPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
    """

    config_class = DacConfig
    base_model_prefix = "dac"
    main_input_name = "input_values"

    def _init_weights(self, module):
        if isinstance(module, nn.Conv1d):
            nn.init.trunc_normal_(module.weight, std=0.02)
            nn.init.constant_(module.bias, 0)

    def apply_weight_norm(self):
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        for layer in self.quantizer.quantizers:
            weight_norm(layer.in_proj)
            weight_norm(layer.out_proj)

        weight_norm(self.encoder.conv1)
        weight_norm(self.encoder.conv2)

        for layer in self.encoder.block:
            weight_norm(layer.conv1)
            weight_norm(layer.res_unit1.conv1)
            weight_norm(layer.res_unit1.conv2)
            weight_norm(layer.res_unit2.conv1)
            weight_norm(layer.res_unit2.conv2)
            weight_norm(layer.res_unit3.conv1)
            weight_norm(layer.res_unit3.conv2)

        weight_norm(self.decoder.conv1)
        weight_norm(self.decoder.conv2)

        for layer in self.decoder.block:
            weight_norm(layer.conv_t1)
            weight_norm(layer.res_unit1.conv1)
            weight_norm(layer.res_unit1.conv2)
            weight_norm(layer.res_unit2.conv1)
            weight_norm(layer.res_unit2.conv2)
            weight_norm(layer.res_unit3.conv1)
            weight_norm(layer.res_unit3.conv2)

    def remove_weight_norm(self):
        for layer in self.quantizer.quantizers:
            nn.utils.remove_weight_norm(layer.in_proj)
            nn.utils.remove_weight_norm(layer.out_proj)

        nn.utils.remove_weight_norm(self.encoder.conv1)
        nn.utils.remove_weight_norm(self.encoder.conv2)

        for layer in self.encoder.block:
            nn.utils.remove_weight_norm(layer.conv1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv2)
            nn.utils.remove_weight_norm(layer.res_unit2.conv1)
            nn.utils.remove_weight_norm(layer.res_unit2.conv2)
            nn.utils.remove_weight_norm(layer.res_unit3.conv1)
            nn.utils.remove_weight_norm(layer.res_unit3.conv2)

        nn.utils.remove_weight_norm(self.decoder.conv1)
        nn.utils.remove_weight_norm(self.decoder.conv2)

        for layer in self.decoder.block:
            nn.utils.remove_weight_norm(layer.conv_t1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv2)
            nn.utils.remove_weight_norm(layer.res_unit2.conv1)
            nn.utils.remove_weight_norm(layer.res_unit2.conv2)
            nn.utils.remove_weight_norm(layer.res_unit3.conv1)
            nn.utils.remove_weight_norm(layer.res_unit3.conv2)


DAC_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DacConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DAC_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`).
            Audio data to encode,
        n_quantizers (`int`, *optional*):
            Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The DAC (Descript Audio Codec) model.",
    DAC_START_DOCSTRING,
)
class DacModel(DacPreTrainedModel):
    def __init__(self, config: DacConfig):
        super().__init__(config)
        self.config = config

        self.encoder = DacEncoder(config)
        self.decoder = DacDecoder(config)

        self.quantizer = DacResidualVectorQuantize(config)

        self.bits_per_codebook = int(math.log2(self.config.codebook_size))
        if 2**self.bits_per_codebook != self.config.codebook_size:
            raise ValueError("The codebook_size must be a power of 2.")

        # Initialize weights and apply final processing
        self.post_init()

    @replace_return_docstrings(output_type=DacEncoderOutput, config_class=_CONFIG_FOR_DOC)
    def encode(
        self,
        input_values: torch.Tensor,
        n_quantizers: Optional[int] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Encode given audio data and return quantized latent codes

        Args:
            input_values (`torch.Tensor of shape `(batch_size, 1, time_steps)`):
                Input audio data to encode,
            n_quantizers (int, *optional*):
                Number of quantizers to use. If None, all quantizers are used. Default is None.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Returns:

        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        quantized_representation = self.encoder(input_values)
        quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss = self.quantizer(
            quantized_representation, n_quantizers
        )

        loss = self.config.commitment_loss_weight * commitment_loss + self.config.codebook_loss_weight * codebook_loss

        if not return_dict:
            return (loss, quantized_representation, audio_codes, projected_latents)

        return DacEncoderOutput(loss, quantized_representation, audio_codes, projected_latents)

    @replace_return_docstrings(output_type=DacDecoderOutput, config_class=_CONFIG_FOR_DOC)
    def decode(
        self,
        quantized_representation: Optional[torch.Tensor] = None,
        audio_codes: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Decode given latent codes and return audio data

        Args:
            quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*):
                Quantized continuous representation of input.
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
                The codebook indices for each codebook, representing the quantized discrete
                representation of the input. This parameter should be provided if you want
                to decode directly from the audio codes (it will overwrite quantized_representation).
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        """

        if quantized_representation is None and audio_codes is None:
            raise ValueError("Either `quantized_representation` or `audio_codes` must be provided.")

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if audio_codes is not None:
            quantized_representation = self.quantizer.from_codes(audio_codes)[0]

        audio_values = self.decoder(quantized_representation).squeeze(1)

        if not return_dict:
            return (audio_values,)

        return DacDecoderOutput(audio_values)

    @add_start_docstrings_to_model_forward(DAC_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DacOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_values: torch.Tensor,
        n_quantizers: Optional[int] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Returns:
        Examples:

        ```python
        >>> from datasets import load_dataset, Audio
        >>> from transformers import DacModel, AutoProcessor
        >>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> model = DacModel.from_pretrained("descript/dac_16khz")
        >>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
        >>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
        >>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
        >>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")

        >>> encoder_outputs = model.encode(inputs["input_values"])
        >>> # Get the intermediate audio codes
        >>> audio_codes = encoder_outputs.audio_codes
        >>> # Reconstruct the audio from its quantized representation
        >>> audio_values = model.decode(encoder_outputs.quantized_representation)
        >>> # or the equivalent with a forward pass
        >>> audio_values = model(inputs["input_values"]).audio_values
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.return_dict
        length = input_values.shape[-1]
        loss, quantized_representation, audio_codes, projected_latents = self.encode(
            input_values, n_quantizers, return_dict=False
        )
        audio_values = self.decode(quantized_representation, return_dict=False)[0][..., :length]

        if not return_dict:
            return (loss, audio_values, quantized_representation, audio_codes, projected_latents)

        return DacOutput(loss, audio_values, quantized_representation, audio_codes, projected_latents)


__all__ = ["DacModel", "DacPreTrainedModel"]
