# mypy: allow-untyped-defs
import io
from typing import Any, Callable, cast

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset

from .metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from .planner import (
    LoadItemType,
    ReadItem,
    SavePlan,
    TensorWriteData,
    WriteItem,
    WriteItemType,
)
from .resharding import (
    _check_shard_metadata_pair_overlap,
    _shards_get_overlap_region_wrt_saved_tensor,
)


__all__: list[str] = ["create_read_items_for_chunk_list"]


def _compare_save_plans(plan: SavePlan, other_plan: SavePlan) -> bool:
    """
    Compare the two Save plans and return True if they are equal.

    Args:
        plan (SavePlan): First SavePlan to compare.
        other_plan (SavePlan): Second SavePlan to compare.

    Returns:
       True if the two plans are equal, False otherwise.
    """
    if plan.usable != other_plan.usable:
        return False

    # Both the plans should have the same number of items
    if len(plan.items) != len(other_plan.items):
        return False

    # Both the plans should have the same write items.
    for plan_item, other_plan_item in zip(plan.items, other_plan.items):
        # Write item type should be same
        if plan_item.type != other_plan_item.type:
            return False

        plan_metadata_index = plan_item.index
        other_plan_metadata_index = other_plan_item.index

        # Write item metadata_index should be same
        if (
            plan_metadata_index.fqn != other_plan_metadata_index.fqn
            or plan_metadata_index.offset != other_plan_metadata_index.offset
            or plan_metadata_index.index != other_plan_metadata_index.index
        ):
            return False

        # Write item tensor_data should be present in both the write items plans, if it exists in either of them.
        tensor_data = plan_item.tensor_data
        other_tensor_data = other_plan_item.tensor_data
        if (tensor_data and not other_tensor_data) or (
            not tensor_data and other_tensor_data
        ):
            return False

        if tensor_data and other_tensor_data:
            # Write item tensor_data size should be same
            if tensor_data.size != other_tensor_data.size:
                return False

            # Write item tensor_data chunk should be present in both the write items, if it exists in either of them.
            chunk = tensor_data.chunk
            other_chunk = other_tensor_data.chunk
            if (chunk and not other_chunk) or (not chunk and other_chunk):
                return False

            # Write item tensor_data chunk offsets and sizes should be same
            if chunk and other_chunk:
                if (
                    chunk.offsets != other_chunk.offsets
                    or chunk.sizes != other_chunk.sizes
                ):
                    return False

    return True


def _merge_delta_local_plans(
    cached_plans: list[SavePlan],
    delta_plans: list[SavePlan],
) -> list[SavePlan]:
    """
    Merge a list of delta plans into a single plan.

    Args:
        cached_plans (List[SavePlan]): A list of cached plans.
        delta_plans (List[SavePlan]): A list of delta plans to merge. It can contain empty plans

    Returns:
        A single merged plan. If a delta plan is not usable, use the cached plan. Otherwise, use the delta plan.
    """
    merged_plans = []

    for cached_plan, delta_plan in zip(cached_plans, delta_plans):
        if delta_plan and not delta_plan.usable:
            merged_plans.append(cached_plan)
        else:
            merged_plans.append(delta_plan)

    return merged_plans


def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()
    )


def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size(shard_md.shard_offsets),
        sizes=torch.Size(shard_md.shard_sizes),
    )


def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
    shard_properties = sharded_tensor.metadata().tensor_properties

    properties = TensorProperties(
        dtype=shard_properties.dtype,
        layout=shard_properties.layout,
        requires_grad=shard_properties.requires_grad,
        memory_format=shard_properties.memory_format,
        pin_memory=shard_properties.pin_memory,
    )

    return TensorWriteData(
        chunk=_chunk_for_shard(shard_md),
        properties=properties,
        size=sharded_tensor.metadata().size,
    )


def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)

    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(
                offsets=offsets,
                sizes=sizes,
            ),
            properties=TensorProperties.create_from_tensor(tensor.to_local()),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
    offsets = torch.Size(shard_md.shard_offsets)
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
    )


def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
    offsets = torch.Size([0] * len(tensor.size()))
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.TENSOR,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
            properties=TensorProperties.create_from_tensor(tensor),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_bytesio(fqn: str, bytes: Any):
    return WriteItem(
        index=MetadataIndex(fqn),
        type=WriteItemType.BYTE_IO,
    )


def _create_read_item_for_byteio(
    dest_index, dest_offset, storage_index, storage_offset, length
):
    return ReadItem(
        type=LoadItemType.BYTE_IO,
        dest_index=dest_index,
        dest_offsets=torch.Size((dest_offset,)),
        storage_index=storage_index,
        storage_offsets=torch.Size((storage_offset,)),
        lengths=torch.Size((length,)),
    )


def _create_read_item_for_tensor(
    dest_index, dest_offsets, storage_index, storage_offsets, lengths
):
    return ReadItem(
        type=LoadItemType.TENSOR,
        dest_index=dest_index,
        dest_offsets=torch.Size(dest_offsets),
        storage_index=storage_index,
        storage_offsets=torch.Size(storage_offsets),
        lengths=torch.Size(lengths),
    )


def create_read_items_for_chunk_list(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_chunks: list[ChunkStorageMetadata],
) -> list[ReadItem]:
    """
    Create a list of ``ReadItem`` based on the checkpoint and local chunks.

    This applies the resharding algorithm and computes the reads needed
    to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.

    Args:
        fqn (str) : The state_dict FQN to pass to ``ReadItem``.
        checkpoint_md (TensorStorageMetadata): metadata for a given tensor
            from a checkpoint.
        local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be
            loaded.

    Returns:
        A list of ``ReadItem`` that will satisfy all input chunks.
    """
    read_items = []
    # this is a naive quadratic algo that can be optimized later
    for idx, shard in enumerate(local_chunks):
        for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
            if not _check_shard_metadata_pair_overlap(shard, storage_md):
                continue

            storage_offsets = []
            dest_offsets = []
            lengths = []
            for (
                _dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=storage_md, current_shard=shard
            ):
                storage_offsets.append(offset_for_saved_tensor)
                dest_offsets.append(offset_for_current_tensor)
                lengths.append(length)

            read_items.append(
                _create_read_item_for_tensor(
                    dest_index=MetadataIndex(fqn, shard.offsets, idx),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return read_items


def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
    requests = []
    for fqn, obj in state_dict.items():
        if isinstance(obj, DTensor):
            requests.append(_create_write_items_for_dtensor(fqn, obj))
        elif isinstance(obj, ShardedTensor):
            requests.extend(
                _create_write_item_for_shard(fqn, obj, shard_md)
                for shard_md in obj.metadata().shards_metadata
            )
        elif isinstance(obj, torch.Tensor):
            requests.append(_create_write_item_for_tensor(fqn, obj))
        else:
            requests.append(_create_write_item_for_bytesio(fqn, obj))
    return SavePlan(requests)


def _create_write_items(fqn: str, object: Any) -> list[WriteItem]:
    if hasattr(object, "__create_write_items__"):
        # DTensor implements _Checkpointable
        return object.__create_write_items__(fqn, object)
    elif isinstance(object, ShardedTensor):
        return [
            _create_write_item_for_shard(fqn, object, shard.metadata)
            for shard in object.local_shards()
        ]
    elif isinstance(object, torch.Tensor):
        return [_create_write_item_for_tensor(fqn, object)]
    else:
        return [_create_write_item_for_bytesio(fqn, object)]


def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)
    return ChunkStorageMetadata(
        offsets=offsets,
        sizes=sizes,
    )


def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]:
    if hasattr(tensor, "__create_chunk_list__"):
        # DTensor implements _Checkpointable
        local_chunks = tensor.__create_chunk_list__()  # type: ignore[attr-defined]
    elif isinstance(tensor, ShardedTensor):
        local_chunks = [
            _chunk_for_shard(shard.metadata) for shard in tensor.local_shards()
        ]
    elif isinstance(tensor, torch.Tensor):
        local_chunks = [_create_chunk_from_tensor(tensor)]
    else:
        raise ValueError(
            "Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor] "
            f",but got {type(tensor)}"
        )

    return local_chunks


def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]:
    if not isinstance(md, BytesStorageMetadata):
        try:
            local_chunks = _create_chunk_list(obj)
        except ValueError as ex:
            raise ValueError(
                f"Invalid checkpoint metadata for {fqn}, "
                + f"expected BytesStorageMetadata but found {type(md)}",
            ) from ex

        return create_read_items_for_chunk_list(fqn, md, local_chunks)
    else:
        return [
            _create_read_item_for_byteio(
                dest_index=MetadataIndex(fqn),
                dest_offset=0,
                storage_index=MetadataIndex(fqn),
                storage_offset=0,
                length=0,
            )
        ]


def _init_state_dict(state_dict: dict[str, Any]) -> Any:
    """
    Initializes meta tensor if the meta tensor is DTensor or torch.Tensor.
    """

    def dtensor_func(value: DTensor):
        device = getattr(value, "device", None)
        if device == torch.device("meta"):
            device_type = dist.distributed_c10d._get_pg_default_device().type
            device = cast(
                torch.device, _get_device_module(device_type).current_device()
            )
            new_local_tensor = torch.empty_like(value.to_local(), device=device)
            # We need to pass shape and stride explicitly, since DTensor might be
            # sharded unevenly.
            dtensor = DTensor.from_local(
                new_local_tensor,
                device_mesh=value.device_mesh,
                placements=value.placements,
                shape=value.size(),
                stride=value.stride(),
            )
            return dtensor
        else:
            return value

    def sharded_tensor_func(value: Any):
        device = getattr(value, "device", None)
        if device == torch.device("meta"):
            raise RuntimeError(
                f"Found unsupported type {type(value)} for meta device loading."
            )
        else:
            return value

    def tensor_func(value: torch.Tensor):
        device = getattr(value, "device", None)
        if device == torch.device("meta"):
            device_type = dist.distributed_c10d._get_pg_default_device().type
            device = cast(
                torch.device, _get_device_module(device_type).current_device()
            )
            tensor = torch.empty_like(value, device=device)
            return tensor
        else:
            return value

    _iterate_state_dict(
        state_dict,
        dtensor_func,
        sharded_tensor_func,
        tensor_func,
    )


def _iterate_state_dict(
    iter_object: Any,
    dtensor_func: Callable,
    sharded_tensor_func: Callable,
    tensor_func: Callable,
):
    """
    Iterate through the state dict, applying the given functions to each tensor type
    and update the state dict in place.

    Args:
        iter_object (Any): the target state_dict.
        sharded_tensor_func (Callable): the function to apply to ShardedTensor
        dtensor_func (Callable): the function to apply to DTensor
        tensor_func (Callable): the function to apply to Tensor

    # TODO: let state_dict_util._iterate_state_dict() to support in place option
    so we don't need to have two versions of _iterate_state_dict.
    """

    if isinstance(iter_object, DTensor):
        return dtensor_func(iter_object)
    elif isinstance(iter_object, ShardedTensor):
        return sharded_tensor_func(iter_object)
    elif isinstance(iter_object, torch.Tensor):
        return tensor_func(iter_object)
    elif (
        isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
        or iter_object is None
    ):
        return iter_object
    elif isinstance(iter_object, dict):
        for key, value in iter_object.items():
            iter_object[key] = _iterate_state_dict(
                value, dtensor_func, sharded_tensor_func, tensor_func
            )
        return iter_object
    elif isinstance(iter_object, (list, tuple)):
        ret = [
            _iterate_state_dict(v, dtensor_func, sharded_tensor_func, tensor_func)
            for v in iter_object
        ]
        if isinstance(iter_object, tuple):
            ret = tuple(ret)  # type: ignore[assignment]
        return ret
