# mypy: allow-untyped-defs
from __future__ import annotations

import io
import logging
import os
from typing import IO, TYPE_CHECKING

import torch
from torch.onnx import _type_utils as jit_type_utils


if TYPE_CHECKING:
    import onnx

    from torch.types import FileLike

log = logging.getLogger(__name__)


def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor,
    name: str,
    location: str,
    basepath: str,
    dtype_override: onnx.TypeProto | None = None,  # type: ignore[name-defined]
) -> onnx.TensorProto:  # type: ignore[name-defined]
    """Create a TensorProto with external data from a PyTorch tensor.
    The external data is saved to os.path.join(basepath, location).

    Args:
        tensor: Tensor to be saved.
        name: Name of the tensor (i.e., initializer name in ONNX graph).
        location: Relative location of the external data file
            (e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx").
        basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp").


    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    scalar_type = (
        jit_type_utils.JitScalarType.from_onnx_type(
            dtype_override.tensor_type.elem_type
        )
        if dtype_override is not None
        else jit_type_utils.JitScalarType.from_dtype(tensor.dtype)
    )

    # Checkpoints can be stored with a different dtype as the model expects because
    # the user script can explicitly cast the original type to something or maybe
    # PyTorch's type promotion might do it
    if dtype_override is not None and scalar_type.dtype() != tensor.dtype:
        tensor = tensor.to(scalar_type.dtype())

    tensor_proto = onnx.TensorProto()  # type: ignore[attr-defined]
    tensor_proto.name = name
    tensor_proto.data_type = scalar_type.onnx_type()  # type: ignore[assignment]

    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL  # type: ignore[attr-defined]

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": tensor.untyped_storage().nbytes(),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write content of tensor.
    external_data_file_path = os.path.join(basepath, location)
    if os.path.exists(external_data_file_path):
        os.remove(external_data_file_path)

    # Create external data's folder if not exists.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    if not os.path.exists(external_data_dir_path):
        # if the demo_folder directory is not present
        # then create it.
        os.makedirs(external_data_dir_path)

    # Create a fresh file.
    with open(external_data_file_path, "xb") as data_file:
        # No need to call "seek" because offset is 0.
        # data_file.seek(0)
        # Write tensor content to the file.
        data_file.write(tensor.numpy(force=True).tobytes())

    return tensor_proto


def _convert_safetensors_to_torch_format(safetensors_file):
    # It this function is called, safetensors is guaranteed to exist
    # because the HF model with safetensors was already loaded and exported to ONNX
    from safetensors import safe_open  # type: ignore[import-not-found, import-untyped]

    tensors = {}
    with safe_open(safetensors_file, framework="pt", device="cpu") as f:  # type: ignore[attr-defined]
        for k in f.keys():
            tensors[k] = f.get_tensor(k).cpu()
    return tensors


# TODO: generalize to allow more checkpoints formats (torch or gguf)
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_state_dicts: tuple[dict | FileLike, ...],
    onnx_model: onnx.ModelProto,  # type: ignore[name-defined]
    rename_initializer: bool = False,
) -> None:
    """Load PyTorch tensors from files and add to "onnx_model" as external initializers.

    Output files:
        ONNX model file path:
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can do
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Arguments:
        basepath: Base path of the ONNX external data file (e.g., "/path/to/large_model/").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "<basepath>/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "<basepath>/initializers/".
            Note: When initializers are >2GB, must be the same as `model_location`.
        torch_state_dicts: Dictionaries or files which contain PyTorch tensors to be saved
            as ONNX initializers. For non-dict arguments, `torch.load` will be used to load them from file-like objects.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_state_dicts",
            the tensor will be saved as that input's external initializer.
        rename_initializer: Replaces "." by "_" for all ONNX initializer names.
            Not needed by the official torch.onnx.dynamo_export. This is a hack
            for supporting `FXSymbolicTracer` tracer with fake tensor mode.
            In short, `FXSymbolicTracer` lifts FX parameters (self.linear_weight)
            as inputs (`def forward(self, linear_weight)`) and therefore, `.` cannot be used.
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    initializers_to_be_deleted = {}  # Using dict because it is **ordered**
    existing_initializers = {
        k.name: idx for idx, k in enumerate(onnx_model.graph.initializer)
    }
    onnx_input_names = {input.name for input in onnx_model.graph.input}
    for el in torch_state_dicts:
        if isinstance(el, dict):
            # Useful for when state_dict is loaded with torch.load(..., mmap=True, map_location="cpu") by the user
            # Using torch.save wouldn't leverage mmap, leading to higher memory usage
            state_dict = el
        else:
            if isinstance(el, (str, os.PathLike)) and os.fspath(el).endswith(
                ".safetensors"
            ):
                state_dict = _convert_safetensors_to_torch_format(el)
            else:
                try:
                    # Loads checkpoint using memory-map on CPU to support really large models
                    # The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
                    state_dict = torch.load(el, map_location="cpu", mmap=True)
                except (RuntimeError, ValueError) as e:
                    if "mmap can only be used with files saved with" in str(e) or (
                        isinstance(el, (io.IOBase, IO))
                        and el.readable()
                        and el.seekable()
                    ):
                        log.warning(
                            "Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
                            "Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
                        )
                        if isinstance(el, (io.IOBase, IO)):
                            el.seek(0)  # torch.load from `try:` has read the file.
                        state_dict = torch.load(el, map_location="cpu")
                    else:
                        raise e

        for name, tensor in state_dict.items():
            if rename_initializer:
                # Basically, "transformer.attention.self.query.weight" is mapped
                # to "transformer_attention_self_query_weight" for mimicking the
                # name-modifying code in FX-to-ONNX exporter.
                # See function _replace_get_attr_with_placeholder for details.
                name = name.replace(".", "_")

            # This block tries to match the onnx initializer name with torch parameter/buffer
            #  e.g. A pytorch buffer 'transformer.h.0.attn.bias' can be named 'h.0.attn.bias' in a ONNX initializer
            # For each PyTorch tensor name loaded by torch.load,
            #  1.  Search its best match in ONNX model. E.g., the match of
            #       "transformer_attention_weight" could be "attention_weight".
            #  2.  Set "tensor" as the initializer of the matched ONNX input.
            #      E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because sometimes, tensor names are stored with prefix the dictionary
            # loaded by torch.load.
            if name in onnx_input_names:
                # Same input name shouldn't be matched again
                onnx_input_names.remove(name)
            else:
                for onnx_input_name in onnx_input_names:
                    if onnx_input_name.endswith(name) or name.endswith(onnx_input_name):
                        # Find a match. Change name to the matched ONNX input name, so that we
                        # create initializer with the right ONNX name.
                        name = onnx_input_name
                        onnx_input_names.remove(onnx_input_name)
                        break

            relative_tensor_file_path = os.path.join(initializer_location, name)
            # Create one file per tensor.
            # tensor_proto.raw_data is stored to external file at
            # os.path.join(basepath, relative_tensor_file_path).
            model_input_types = {k.name: k.type for k in onnx_model.graph.input}

            # Mark for deletion - a replacement will be appended next
            if name in existing_initializers:
                initializers_to_be_deleted[existing_initializers[name]] = name
            tensor_proto = _create_tensor_proto_with_external_data(
                tensor,
                name,
                relative_tensor_file_path,
                basepath,
                model_input_types.pop(name, None),
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model.graph.initializer.append(tensor_proto)
    # Remove old duplicated initializers, if any. delete in desc order to not invalidate deletion indices
    initializers_to_be_deleted = dict(
        sorted(initializers_to_be_deleted.items(), reverse=True)
    )
    for idx in initializers_to_be_deleted.keys():
        del onnx_model.graph.initializer[idx]

    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
    onnx.save(onnx_model, os.path.join(basepath, model_location))  # type: ignore[attr-defined]
