# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
from typing import Callable, TYPE_CHECKING

import torch
import torch._ops
from torch._dispatch import python as python_dispatch
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils


if TYPE_CHECKING:
    from collections.abc import Mapping

    import torch.fx


class Decompose(_pass.Transform):
    """Re-trace an FX graph module with ``make_fx`` while applying a decomposition table.

    The module is traced through ``proxy_tensor.make_fx`` with
    ``decomposition_table`` supplied, and the resulting graph module is returned
    with its placeholders renamed to match the original module's signature.
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        decomposition_table: Mapping[torch._ops.OpOverload, Callable],
        enable_dynamic_axes: bool,
        allow_fake_constant: bool | None = False,
    ):
        """Store the decomposition configuration.

        Args:
            diagnostic_context: Forwarded to the ``_pass.Transform`` base class.
            module: The graph module to decompose; forwarded to the base class.
            decomposition_table: Passed to ``make_fx`` as its ``decomposition_table``.
            enable_dynamic_axes: Only consulted when no existing fake mode is
                available. ``True`` selects ``"symbolic"`` tracing and ``False``
                selects ``"fake"`` tracing.
            allow_fake_constant: Passed to ``make_fx`` as ``_allow_fake_constant``.
                ``None`` is treated as ``False``.
        """
        super().__init__(diagnostic_context, module)
        self.decomposition_table = decomposition_table
        self.enable_dynamic_axes = enable_dynamic_axes
        self.allow_fake_constant = allow_fake_constant

    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
        """Trace ``self.module`` with decompositions applied and return the result.

        Args:
            *args: Example inputs for the module. They are passed through
                ``self._maybe_fakefy_args`` before being fed to the traced module.
            **kwargs: Not supported; must be empty.

        Returns:
            The decomposed ``torch.fx.GraphModule``. Its placeholder names and
            targets are replaced to match ``self.module``.
        """
        assert not kwargs, "kwargs is not supported in Decompose."

        # To preserve stack trace info after `make_fx`.
        module = _utils.wrap_graph_module_for_node_meta_preservation(self.module)

        # Fake mode uses static sizes when tracing tensor sizes, while symbolic
        # mode emits aten::sym_size calls to trace them dynamically.

        # e.g. fake mode:
        #  view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [3, 5, 20])

        # e.g. symbolic mode:
        #  sym_size = torch.ops.aten.sym_size(x, 0)
        #  sym_size_1 = torch.ops.aten.sym_size(x, 1)
        #  sym_size_2 = torch.ops.aten.sym_size(x, 2)
        #  sym_size_3 = torch.ops.aten.sym_size(x, 3)
        #  mul = sym_size_2 * sym_size_3;  sym_size_2 = sym_size_3 = None
        #  view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [sym_size, sym_size_1, mul])

        # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`.
        # TODO: May need revisit for user fake mode export + dynamic shape scenario.
        fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode
        maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args)

        # Keep the optional fake mode and the context manager entered around
        # `make_fx` in separate variables. This avoids rebinding `fake_mode` to an
        # unrelated type, which previously needed a `type: ignore` and an assert.
        mode_context: contextlib.AbstractContextManager
        if fake_mode is not None:
            # Use the existing fake mode as the context. Passing tracing_mode
            # "real" tells `make_fx` that it does not need to create a new one.
            tracing_mode = "real"
            mode_context = fake_mode
        else:
            # No existing fake mode was found, so signal `make_fx` to create one.
            tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake"
            mode_context = contextlib.nullcontext()

        # Apply decomposition table to the input graph.
        with (
            fake_tensor.unset_fake_temporarily(),
            python_dispatch.enable_python_dispatcher(),
            mode_context,
        ):
            decomposed_module = proxy_tensor.make_fx(
                module,
                decomposition_table=self.decomposition_table,
                tracing_mode=tracing_mode,
                _allow_non_fake_inputs=True,
                _allow_fake_constant=bool(self.allow_fake_constant),
            )(*maybe_fake_args)

        # Rename placeholder targets to match the original module's signature,
        # since we don't want to map forward(x, y, z) to forward(arg0, arg1, arg2).
        _utils.replace_placeholder_name_and_target(decomposed_module, self.module)

        return decomposed_module
