# mypy: ignore-errors

"""
This file contains a collection of context manager classes used by Dynamo for tracking
and managing various PyTorch runtime states during graph compilation. These context
managers handle different aspects of PyTorch's execution environment, including:

- Autograd states (grad mode, inference mode)
- CUDA streams and events
- Profiling contexts
- Deterministic algorithms
- Forward/backward AD modes
- SDPA (Scaled Dot Product Attention) kernels
- FSDP (Fully Sharded Data Parallel) states
- AMP (Automatic Mixed Precision) autocast states

The context managers ensure proper state transitions during graph compilation by
tracking enter/exit points and managing cleanup operations. They help maintain
consistency between eager execution and compiled graph behavior by capturing and
restoring state changes.
"""

import dataclasses
import inspect
import sys
import warnings
from typing import Callable, Optional, TYPE_CHECKING, Union

import torch._C
from torch._guards import Guard

from .. import graph_break_hints, variables
from ..bytecode_transformation import (
    create_call_function,
    create_instruction,
    create_setup_with,
)
from ..device_interface import get_interface_for_device
from ..exc import unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
from .base import VariableTracker
from .functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
    WrappedUserFunctionVariable,
    WrappedUserMethodVariable,
)
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


@dataclasses.dataclass
class ContextManagerState:
    """
    Shared, mutable bookkeeping for a context-manager VariableTracker.

    VariableTrackers get copied, so mutating ``self`` on them is not allowed.
    Context managers instead point at one of these containers, which is not
    copied and can therefore be mutated safely.
    """

    # Pending "undo" callback; reset to None after it has run.
    cleanup_fn: Optional[Callable] = None
    # Graph value recorded on enter by subclasses that create one.
    proxy: Optional[torch.fx.Proxy] = None

    def cleanup(self):
        """Run the pending cleanup callback (if any), then clear it so that
        later calls are no-ops."""
        pending = self.cleanup_fn
        if pending is None:
            return
        pending()
        self.cleanup_fn = None

    def cleanup_assert(self):
        """Same as :meth:`cleanup`, but assert that a callback is still pending
        (catches a context being exited twice)."""
        assert self.cleanup_fn, "multiple exits?"
        self.cleanup()


class ContextWrappingVariable(VariableTracker):
    """
    Base class for context managers whose effect Dynamo models directly.

    ``target_values`` is the state applied on enter (via ``_call_func``), and
    ``initial_values`` is the state the default cleanup hook re-applies.
    ``_call_func`` is not defined here; subclasses provide it. Mutable
    bookkeeping lives in ``self.state`` (a ``ContextManagerState``) because
    VariableTrackers get copied.
    """

    _nonvar_fields = {
        "cm_obj",
        "target_values",
        "initial_values",
        "state",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self, target_values, initial_values=None, *, state=None, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.target_values = target_values
        self.initial_values = initial_values
        self.state = ContextManagerState() if state is None else state

    def enter(self, tx):
        """Apply ``target_values`` and register the restore hook."""
        self._call_func(tx, self.target_values)
        self.set_cleanup_hook(tx)
        return variables.ConstantVariable.create(None)

    def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None):
        """Store ``fn`` (default: re-apply ``initial_values``) as the pending
        cleanup and register ``state.cleanup`` with the output graph."""
        if fn is None:

            def fn():
                self._call_func(tx, self.initial_values)

        self.state.cleanup_fn = fn
        tx.output.add_cleanup_hook(self.state.cleanup)

    def exit(self, tx: "InstructionTranslator", *args):
        """Run the pending cleanup exactly once (asserts on a double exit)."""
        self.state.cleanup_assert()
        return variables.ConstantVariable.create(None)

    def reconstruct_type(self, codegen):
        """Emit code that loads ``<module_name()>.<fn_name()>``."""
        codegen(
            AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
        )

    def reconstruct(self, codegen):
        """Emit ``<module>.<fn>(*target_values)`` with constant arguments."""
        codegen.add_push_null(lambda: self.reconstruct_type(codegen))
        target_values = self.target_values
        if not target_values:
            target_values = ()
        codegen.extend_output([codegen.create_load_const(val) for val in target_values])
        codegen.extend_output(create_call_function(len(target_values), False))

    def module_name(self):
        raise NotImplementedError("module_name called on base")

    def fn_name(self):
        raise NotImplementedError("fn_name called on base")

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Wrap a single user function/method so it runs inside this context.

        Presumably this models applying the context manager as a decorator;
        confirm against callers.
        """
        assert len(args) == 1
        # Use a local instead of writing back into the caller's ``args`` list.
        fn = args[0]
        if isinstance(fn, NestedUserFunctionVariable):
            fn = UserFunctionVariable(fn.get_function())
        assert isinstance(fn, (UserMethodVariable, UserFunctionVariable))

        # Methods are checked first so they get the method-aware wrapper.
        if isinstance(fn, UserMethodVariable):
            return WrappedUserMethodVariable(fn, self)
        return WrappedUserFunctionVariable(fn, self)

    def supports_graph_breaks(self):
        return True

    def exit_on_graph_break(self):
        return True


class GenericContextWrappingVariable(UserDefinedObjectVariable):
    """Wraps an arbitrary context-manager object whose ``__enter__`` and
    ``__exit__`` are traced as ordinary user methods."""

    # Some methods in ContextWrappingVariable assume their arguments are
    # Python constants, which might not always be the case here.
    def __init__(self, cm_obj, **kwargs) -> None:
        assert cm_obj is not None
        super().__init__(value=cm_obj, value_type=cm_obj.__class__, **kwargs)
        self.cm_obj = cm_obj

    def module_name(self):
        return self.cm_obj.__module__

    def fn_name(self):
        return type(self.cm_obj).__name__

    def enter(self, tx):
        # Trace the underlying __enter__ function, bound to this variable.
        enter_src = (
            AttrSource(self.source, "__enter__") if self.source is not None else None
        )
        enter_fn = self.cm_obj.__enter__.__func__
        traced = variables.UserMethodVariable(enter_fn, self, source=enter_src)
        return traced.call_function(tx, [], {})

    def exit(self, tx: "InstructionTranslator", *args):
        exit_src = (
            AttrSource(self.source, "__exit__") if self.source is not None else None
        )
        exit_fn = self.cm_obj.__exit__.__func__
        traced = variables.UserMethodVariable(exit_fn, self, source=exit_src)
        result = traced.call_function(tx, args, {})
        # Drop this manager from the active stack only after __exit__ is traced.
        tx.active_generic_context_managers.pop()
        return result

    def supports_graph_breaks(self):
        return False

    def exit_on_graph_break(self):
        return True


class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
    """represents torch grad requires grad (functorch
    ``set_inplace_requires_grad_allowed`` toggling)"""

    @staticmethod
    def create(tx: "InstructionTranslator", target_values, **kwargs):
        # initial_values stays None: the previous setting is captured in enter().
        return GradInplaceRequiresGradCtxManagerVariable(
            target_values=target_values, initial_values=None, **kwargs
        )

    def enter(self, tx):
        (allowed,) = self.target_values
        functorch_c = torch._C._functorch
        # Save the eager setting so the cleanup hook and exit() can restore it.
        self.prev_state = functorch_c.get_inplace_requires_grad_allowed()
        functorch_c.set_inplace_requires_grad_allowed(allowed)

        def restore_previous():
            torch._C._functorch.set_inplace_requires_grad_allowed(self.prev_state)

        self.set_cleanup_hook(tx, restore_previous)
        # Mirror the toggle in the FX graph.
        self.state.proxy = tx.output.create_node(
            "call_function",
            functorch_c.set_inplace_requires_grad_allowed,
            (allowed,),
            {},
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        restore_args = (self.prev_state,)
        tx.output.create_node(
            "call_function",
            torch._C._functorch.set_inplace_requires_grad_allowed,
            restore_args,
            {},
        )
        return variables.ConstantVariable.create(None)


class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable):
    """represents torch._functorch.pyfunctorch.temporarily_pop_interpreter_stack()"""

    @staticmethod
    def create(tx: "InstructionTranslator", target_values, **kwargs):
        return TemporarilyPopInterpreterStackCtxManagerVariable(
            target_values=target_values, initial_values=None, **kwargs
        )

    def enter(self, tx):
        # Pop the top dynamic layer eagerly, keeping it so it can be pushed back.
        self.saved = torch._C._functorch.pop_dynamic_layer_stack()

        def push_back():
            torch._C._functorch.push_dynamic_layer_stack(self.saved)

        self.set_cleanup_hook(tx, push_back)
        self.state.proxy = tx.output.create_node(
            "call_function", torch._C._functorch.pop_dynamic_layer_stack, (), {}
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        # In the graph, push back whatever the graph's own pop node produced.
        popped_node = self.state.proxy
        tx.output.create_node(
            "call_function",
            torch._C._functorch.push_dynamic_layer_stack,
            (popped_node,),
            {},
        )
        return variables.ConstantVariable.create(None)


class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch.func.jvp increment/decrement nesting"""

    # The jvp nesting level is baked into the generated FX graph. That is fine
    # when jvp is only called from within the function being compiled, but the
    # graph may be invalid when eager code calls jvp around the compiled
    # function, because the jvp levels may then differ. Hence this guard.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        return JvpIncrementNestingCtxManagerVariable(
            target_values=None, initial_values=None, **kwargs
        )

    def enter(self, tx):
        install_guard(self._guards_singleton)
        level = torch._functorch.eager_transforms.enter_jvp_nesting()

        def undo_nesting():
            torch._functorch.eager_transforms.exit_jvp_nesting()

        self.set_cleanup_hook(tx, undo_nesting)
        self.state.proxy = tx.output.create_node(
            "call_function", torch._C._functorch._jvp_increment_nesting, (), {}
        )
        return variables.ConstantVariable.create(level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        decrement = torch._C._functorch._jvp_decrement_nesting
        tx.output.create_node("call_function", decrement, (), {})
        return variables.ConstantVariable.create(None)


class SetFwdGradEnabledContextManager(ContextWrappingVariable):
    """represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad"""

    @staticmethod
    def create(tx: "InstructionTranslator", target_values, **kwargs):
        return SetFwdGradEnabledContextManager(
            target_values=target_values, initial_values=None, **kwargs
        )

    def enter(self, tx):
        (enabled,) = self.target_values
        # Remember the eager forward-grad flag before overwriting it.
        self.prev_state = torch._C._is_fwd_grad_enabled()
        torch._C._set_fwd_grad_enabled(enabled)

        def restore_flag():
            torch._C._set_fwd_grad_enabled(self.prev_state)

        self.set_cleanup_hook(tx, restore_flag)
        self.state.proxy = tx.output.create_node(
            "call_function", torch._C._set_fwd_grad_enabled, (enabled,), {}
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._set_fwd_grad_enabled, (self.prev_state,), {}
        )
        return variables.ConstantVariable.create(None)


class DualLevelContextManager(ContextWrappingVariable):
    """Represents torch.autograd.forward_ad.dual_level ctx manager"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        return DualLevelContextManager(
            target_values=None, initial_values=None, **kwargs
        )

    def enter(self, tx):
        install_guard(self._guards_singleton)
        forward_ad = torch.autograd.forward_ad
        self.new_level = forward_ad.enter_dual_level()

        def leave_level():
            torch.autograd.forward_ad.exit_dual_level(level=self.new_level)

        self.set_cleanup_hook(tx, leave_level)
        self.state.proxy = tx.output.create_node(
            "call_function", torch._C._enter_dual_level, (), {}
        )
        return variables.ConstantVariable.create(self.new_level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        # The concrete level obtained in enter() is baked into the exit node.
        tx.output.create_node(
            "call_function", torch._C._exit_dual_level, (self.new_level,), {}
        )
        return variables.ConstantVariable.create(None)


class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch.func.grad increment/decrement nesting"""

    # The grad nesting level is baked into the generated FX graph. That is fine
    # when grad is only called from within the function being compiled, but the
    # graph may be invalid when eager code calls grad around the compiled
    # function, because the grad levels may then differ. Hence this guard.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        return GradIncrementNestingCtxManagerVariable(
            target_values=None, initial_values=None, **kwargs
        )

    def enter(self, tx):
        install_guard(self._guards_singleton)
        level = torch._C._functorch._grad_increment_nesting()

        def undo_nesting():
            torch._C._functorch._grad_decrement_nesting()

        self.set_cleanup_hook(tx, undo_nesting)
        self.state.proxy = tx.output.create_node(
            "call_function", torch._C._functorch._grad_increment_nesting, (), {}
        )
        return variables.ConstantVariable.create(level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        decrement = torch._C._functorch._grad_decrement_nesting
        tx.output.create_node("call_function", decrement, (), {})
        return variables.ConstantVariable.create(None)


class CatchWarningsCtxManagerVariable(ContextWrappingVariable):
    """Delay a call to warnings.catch_warnings"""

    @staticmethod
    def create(tx: "InstructionTranslator", catch_warnings_args):
        return CatchWarningsCtxManagerVariable(
            catch_warnings_args=catch_warnings_args,
            target_values=None,
            initial_values=None,
        )

    def __init__(self, catch_warnings_args, **kwargs) -> None:
        # Maps keyword name -> VariableTracker holding that keyword's value.
        assert isinstance(catch_warnings_args, dict), catch_warnings_args
        super().__init__(**kwargs)
        self.catch_warnings_args = catch_warnings_args

    def enter(self, tx):
        # Build a real catch_warnings object from the constant keyword values.
        py_kwargs = {}
        for name, vt in self.catch_warnings_args.items():
            py_kwargs[name] = vt.as_python_constant()
        cm = warnings.catch_warnings(**py_kwargs)
        self.set_cleanup_hook(tx, lambda: cm.__exit__(None, None, None))
        entered = cm.__enter__()
        return variables.ConstantVariable.create(entered)

    def reconstruct(self, cg):
        # Emit ``warnings.catch_warnings(**catch_warnings_args)``.
        cg.add_push_null(lambda: cg.load_import_from("warnings", "catch_warnings"))
        cg.foreach(self.catch_warnings_args.values())
        kw_names = tuple(self.catch_warnings_args.keys())
        cg.extend_output(cg.create_call_function_kw(len(kw_names), kw_names, False))


class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch VMap increment/decrement nesting"""

    # The vmap level is baked into the generated FX graph. That is fine when
    # vmap is only called from within the function being compiled, but the
    # graph may be invalid when eager code calls vmap around the compiled
    # function, because the vmap levels may then differ. Hence this guard.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx: "InstructionTranslator", target_values, **kwargs):
        return VmapIncrementNestingCtxManagerVariable(
            target_values=target_values, initial_values=None, **kwargs
        )

    def enter(self, tx):
        install_guard(self._guards_singleton)
        batch_size_vt, randomness_vt = self.target_values
        if not isinstance(batch_size_vt, variables.SymNodeVariable):
            # Static batch size: the same constant feeds eager and graph.
            eager_batch_size = batch_size_vt.as_python_constant()
            graph_batch_size = batch_size_vt.as_python_constant()
        else:
            # Symbolic batch size: eager uses the SymInt, the graph uses its node.
            eager_batch_size = batch_size_vt.sym_num
            graph_batch_size = batch_size_vt.as_proxy().node
        randomness = randomness_vt.as_python_constant()
        level = torch._C._functorch._vmap_increment_nesting(
            eager_batch_size, randomness
        )

        def undo_nesting():
            torch._C._functorch._vmap_decrement_nesting()

        self.set_cleanup_hook(tx, undo_nesting)
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._vmap_increment_nesting,
            (graph_batch_size, randomness),
            {},
        )
        return variables.ConstantVariable.create(level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        decrement = torch._C._functorch._vmap_decrement_nesting
        tx.output.create_node("call_function", decrement, (), {})
        return variables.ConstantVariable.create(None)


class GradModeVariable(ContextWrappingVariable):
    """represents torch.{no_grad,enable_grad,set_grad_mode}()"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)

    @staticmethod
    def create(tx: "InstructionTranslator", target_value, initialized=False, **kwargs):
        current_mode = torch.is_grad_enabled()
        grad_mode_var = GradModeVariable(
            target_values=[target_value],
            initial_values=[current_mode],
            **kwargs,
        )
        # When ``initialized``, switch to the target mode right away;
        # call_function() undoes this if the object is later called.
        if initialized:
            grad_mode_var._call_func(tx, grad_mode_var.target_values)
        return grad_mode_var

    def __init__(
        self, target_values, initial_values=None, initialized=True, **kwargs
    ) -> None:
        # ``initialized`` is accepted but not used here.
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        # No cleanup hook is registered; exit() restores the mode instead.
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ):
        # Revert the eager switch performed in create(initialized=True).
        self._call_func(tx, self.initial_values)
        return super().call_function(tx, args, kwargs)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        desired = values[0]
        if torch.is_grad_enabled() == desired:
            # Already in the requested mode: emit nothing (coalescing).
            return
        tx.output.create_node(
            "call_function", torch._C._set_grad_enabled, (desired,), {}
        )
        torch._C._set_grad_enabled(desired)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "set_grad_enabled"


class InferenceModeVariable(ContextWrappingVariable):
    """Represents ``torch.inference_mode``."""

    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        var = InferenceModeVariable(
            [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs
        )
        return var

    def __init__(
        self,
        target_values,
        initial_values=None,
        **kwargs,
    ) -> None:
        if initial_values is None:
            # This must be called here since function defaults are evaluated at import time
            initial_values = torch.is_inference_mode_enabled()
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def exit(self, tx: "InstructionTranslator", *args):
        """Undo the eager inference-mode switch and record the matching exit
        node. The exit node is fed the node created by enter()."""
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._exit_inference_mode,
            (self.state.proxy,),
            {},
        )
        # Return a VariableTracker like the base class (previously bare None).
        return variables.ConstantVariable.create(None)

    def enter(self, tx):
        """Enter inference mode eagerly, register the eager undo, and record
        ``_enter_inference_mode(*target_values)`` in the graph."""
        ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values)
        self.set_cleanup_hook(
            tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._enter_inference_mode,
            (*self.target_values,),
            {},
        )
        # The ``with ... as x`` value is None, as in the base class.
        return variables.ConstantVariable.create(None)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "inference_mode"


class CUDADeviceVariable(ContextWrappingVariable):
    """represents torch.cuda.device"""

    @staticmethod
    def create(tx: "InstructionTranslator", device, **kwargs):
        var = CUDADeviceVariable(
            target_values=[torch.cuda._get_device_index(device, optional=True)],
            initial_values=None,
            **kwargs,
        )
        return var

    def __init__(
        self,
        target_values,
        initial_values=None,
        **kwargs,
    ) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def exit(self, tx: "InstructionTranslator", *args):
        """Restore the previous device eagerly and in the graph. The graph
        restore takes the result of the enter() exchange node."""
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function",
            torch.cuda._maybe_exchange_device,
            (self.state.proxy,),
            {},
        )
        return variables.ConstantVariable.create(False)

    def enter(self, tx):
        """Switch to the target device eagerly and in the graph, remembering the
        previous index for the eager cleanup."""
        prev_idx = torch.cuda._exchange_device(*self.target_values)
        self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx))
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch.cuda._exchange_device,
            (*self.target_values,),
            {},
        )
        # Return a VariableTracker like the base class (previously bare None).
        return variables.ConstantVariable.create(None)

    def module_name(self):
        return "torch.cuda"

    def fn_name(self):
        return "device"


class TorchFunctionDisableVariable(ContextWrappingVariable):
    """represents whether torch function overrides are enabled or not"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        was_enabled = tx.output.torch_function_enabled
        disable_var = TorchFunctionDisableVariable(
            target_values=[False], initial_values=[was_enabled], **kwargs
        )
        # mlazos: I think this is here to make sure we don't reinvoke on clone()
        # Disabling happens here, at creation; the cleanup hook re-applies
        # the previous state.
        disable_var._call_func(tx, [False])
        disable_var.set_cleanup_hook(tx)
        return disable_var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        # State was already switched in create().
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        enabled = values[0]
        tf_state = tx.symbolic_torch_function_state
        tf_state.torch_function_subclass_enabled = enabled
        tf_state.torch_function_mode_enabled = enabled
        tx.output.set_torch_function_state(enabled)


class DeterministicAlgorithmsVariable(ContextWrappingVariable):
    """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

    _guards_singleton = Guard(
        GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
    )

    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        # The new setting is applied immediately (graph node + eager state);
        # the default cleanup hook re-applies the previous setting.
        var = DeterministicAlgorithmsVariable(
            target_values=[target_value],
            initial_values=[torch.are_deterministic_algorithms_enabled()],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        # Nothing to do: create() already applied the setting.
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx: "InstructionTranslator", values):
        """Record ``torch._C._set_deterministic_algorithms(value)`` in the graph
        and apply it eagerly."""
        assert len(values) == 1
        value = values[0]
        # Previously this call was wrapped in a stray, discarded 1-tuple.
        tx.output.create_node(
            "call_function", torch._C._set_deterministic_algorithms, (value,), {}
        )
        torch._C._set_deterministic_algorithms(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "use_deterministic_algorithms"


class DisabledSavedTensorsHooksVariable(ContextWrappingVariable):
    """represents torch.autograd.graph.disable_saved_tensors_hook."""

    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        previous_message = (
            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        )
        hooks_var = DisabledSavedTensorsHooksVariable(
            target_values=[target_value],
            initial_values=[previous_message],
            **kwargs,
        )
        # Applied at creation; the cleanup hook re-applies the previous message.
        hooks_var._call_func(tx, [target_value])
        hooks_var.set_cleanup_hook(tx)
        return hooks_var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        message = values[0]
        autograd_c = torch._C._autograd
        if message is None:
            # We are exiting this context and the previous message was None,
            # so re-enable `saved_tensors_hooks`.
            tx.output.create_node(
                "call_function", autograd_c._saved_tensors_hooks_enable, (), {}
            )
            autograd_c._saved_tensors_hooks_enable()
            return
        # Either disabling `saved_tensors_hooks` with `message`, or exiting this
        # context and restoring the previous message.
        tx.output.create_node(
            "call_function",
            autograd_c._saved_tensors_hooks_disable,
            (message,),
            {},
        )
        autograd_c._saved_tensors_hooks_disable(message)

    def module_name(self):
        return "torch.autograd.graph"

    def fn_name(self):
        return "disable_saved_tensors_hooks"


class AutocastModeVariable(ContextWrappingVariable):
    """Represents ``torch.amp.autocast`` and the deprecated
    ``torch.{cuda,cpu}.amp.autocast`` aliases."""

    @staticmethod
    def create(func, args, kwargs):
        assert func in [
            torch.amp.autocast_mode.autocast,
            torch.cuda.amp.autocast,
            torch.cpu.amp.autocast,
        ]
        # device_type : str,
        # dtype : Optional[_dtype] = None,
        # enabled : bool = True,
        # cache_enabled : Optional[bool] = None):cache_enabled
        bound_args = inspect.signature(func).bind(*args, **kwargs)
        bound_args.apply_defaults()
        target_values = []
        # NOTE: this clears the caller's dict. The autocast kwargs are folded
        # into target_values, so none are forwarded to the VariableTracker
        # constructor below.
        kwargs.clear()

        for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
            if key == "device_type" and func in [
                torch.cuda.amp.autocast,
                torch.cpu.amp.autocast,
            ]:
                # The device-specific aliases take no device_type argument.
                arg = "cuda" if func is torch.cuda.amp.autocast else "cpu"
            else:
                arg = bound_args.arguments[key]
            if isinstance(arg, VariableTracker):
                target_values.append(arg.as_python_constant())
            else:
                target_values.append(arg)

        var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def exit(self, tx: "InstructionTranslator", *args):
        """Leave autocast eagerly and record ``_exit_autocast`` in the graph.
        The graph call is fed the node created by enter()."""
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
        )
        # Return a VariableTracker like the base class (previously bare None).
        return variables.ConstantVariable.create(None)

    def enter(self, tx):
        """Enter autocast eagerly, register the eager undo, and record
        ``_enter_autocast(*target_values)`` in the graph."""
        ctx = torch.amp._enter_autocast(*self.target_values)
        self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
        self.state.proxy = tx.output.create_node(
            "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
        )
        # Return a VariableTracker like the base class (previously bare None).
        return variables.ConstantVariable.create(None)

    def module_name(self):
        return "torch.amp.autocast_mode"

    def fn_name(self):
        return "autocast"


class NullContextVariable(ContextWrappingVariable):
    """
    Models Python's ``contextlib.nullcontext``: entering and exiting have no
    effect, and both produce a None constant.
    """

    def __init__(self, target_values=None, **kwargs) -> None:
        super().__init__(target_values=target_values, **kwargs)

    def enter(self, tx):
        none_vt = variables.ConstantVariable.create(None)
        return none_vt

    def exit(self, tx: "InstructionTranslator", *args):
        none_vt = variables.ConstantVariable.create(None)
        return none_vt

    def module_name(self):
        return "contextlib"

    def fn_name(self):
        return "nullcontext"


class ProfilerContextVariable(ContextWrappingVariable):
    """
    Stands in for torch profiler context objects. Dynamo ignores every side
    effect of their ``__init__``, ``__enter__`` and ``__exit__`` and treats the
    object largely like ``contextlib.nullcontext``. The one exception is that
    ``__enter__`` yields the object itself instead of ``None``, matching the
    torch implementation.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(target_values=None, **kwargs)

    def enter(self, tx):
        # ``with profiler(...) as p`` binds p to this very variable.
        return self

    def exit(self, tx: "InstructionTranslator", *args):
        return variables.ConstantVariable.create(None)

    def module_name(self):
        return "contextlib"

    def fn_name(self):
        return "nullcontext"

    def reconstruct(self, cg):
        # Profiler objects cannot be rebuilt in the output; always graph-break.
        explanation = "Dynamo doesn't support compiling a region that returns a torch.profiler context manager."
        unimplemented_v2(
            gb_type="torch.profiler object escaped from compiled region",
            context=str(self),
            explanation=explanation,
            hints=[*graph_break_hints.SUPPORTABLE],
        )


class StreamContextVariable(ContextWrappingVariable):
    """
    Tracks a stream context manager (e.g. ``with torch.cuda.stream(s):``).

    On entry it records a "set current stream" call in the FX graph and also
    switches the real current stream during tracing. On exit it records a
    call that restores the stream captured in ``initial_values``.
    """

    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        """
        Build the context variable for switching to ``target_value``.

        The stream that is current at this point is captured as a traced
        ``current_stream`` call, so exit can restore it.
        """
        from .builder import wrap_fx_proxy_cls

        current_stream_method = get_interface_for_device(
            target_value.device
        ).current_stream
        current_stream = wrap_fx_proxy_cls(
            StreamVariable,
            tx,
            tx.output.create_proxy(
                "call_function",
                current_stream_method,
                (None,),
                {},
            ),
        )
        return StreamContextVariable(
            target_values=[target_value],
            initial_values=[current_stream],
            device=target_value.device,
            **kwargs,
        )

    def __init__(self, target_values, device, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.device = device
        # Look the device interface up once and reuse it for both setters.
        device_interface = get_interface_for_device(self.device)
        self.set_stream = device_interface.set_stream
        self.set_stream_id = device_interface._set_stream_by_id

    def enter(self, tx):
        """
        Emit a graph node that activates the target stream, and switch the
        real stream while tracing.

        Returns a constant ``None`` variable.
        """
        target = self.target_values[0]
        if target.as_proxy() is not None:
            # The stream was created inside the traced function, so it is
            # available as a graph value.
            tx.output.create_proxy(
                "call_function",
                self.set_stream,
                (target.as_proxy(),),
                {},
            )
        else:
            # The stream was passed in from outside. Bake its identifying
            # fields into the graph as constants.
            stream = target.value
            tx.output.create_proxy(
                "call_function",
                self.set_stream_id,
                (stream.stream_id, stream.device_index, stream.device_type),
                {},
            )
        # Keep the tracing-time current stream consistent with the graph.
        # The cleanup hook undoes this if tracing is abandoned mid-block.
        self.set_stream(target.value)
        self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))
        # Previously this returned a raw Python None. Return a traced None
        # like the other context variables, so the enter result is always a
        # VariableTracker.
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        """
        Emit a graph node restoring the stream that was current at creation,
        then run the cleanup hook that restores the real stream.

        Returns a constant ``None`` variable.
        """
        tx.output.create_proxy(
            "call_function",
            self.set_stream,
            (self.initial_values[0].as_proxy(),),
            {},
        )
        self.state.cleanup_assert()
        # Match the sibling context variables: return a traced None rather
        # than falling off the end with a raw Python None.
        return variables.ConstantVariable.create(None)


class PreserveVersionContextVariable(ContextWrappingVariable):
    """
    Wraps torch.autograd._unsafe_preserve_version_counter.

    The ``_version`` of each tensor is read when the context manager is
    constructed. On exit, a traced ``_unsafe_set_version_counter`` call writes
    those saved versions back onto the tensors. Entering records nothing.
    """

    @staticmethod
    def _create_lambda_from_tensors(tx, tensors):
        """
        Build the context variable from a single tensor or a tuple of tensors.

        A lone ``TensorVariable`` is wrapped into a 1-tuple first, so the
        saved versions always line up one-to-one with ``tensors.items``.
        """
        if isinstance(tensors, variables.TensorVariable):
            tensors = variables.TupleVariable([tensors])
        versions = variables.TupleVariable(
            [x.var_getattr(tx, "_version") for x in tensors.items]
        )
        return PreserveVersionContextVariable(tensors, versions)

    @staticmethod
    def constructor(tx):
        """Return a callable variable that constructs this context manager."""
        return variables.LambdaVariable(
            lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors(
                tx, tensors
            )
        )

    def __init__(self, tensors, prev_versions, **kwargs) -> None:
        kwargs.setdefault("target_values", None)
        super().__init__(**kwargs)
        self.tensors = tensors
        self.prev_versions = prev_versions
        # The context manager accepts Union[Tensor, Tuple[Tensor]]; normalize
        # both the tensors and their versions to tuples.
        if isinstance(self.tensors, variables.TensorVariable):
            self.tensors = variables.TupleVariable([self.tensors])
        if isinstance(
            self.prev_versions, (variables.ConstantVariable, variables.SymNodeVariable)
        ):
            self.prev_versions = variables.TupleVariable([self.prev_versions])

    def enter(self, tx):
        # Nothing to record on entry. Previously this was `pass`, which
        # returned a raw Python None. Return a traced None like the other
        # context variables.
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        """Emit the call that restores the saved version counters."""
        from ..tensor_version_op import _unsafe_set_version_counter

        return variables.TorchInGraphFunctionVariable(
            _unsafe_set_version_counter
        ).call_function(tx, [self.tensors, self.prev_versions], {})

    def reconstruct(self, codegen):
        # The context manager cannot be rebuilt in output bytecode, so it
        # must not escape the compiled region.
        unimplemented_v2(
            gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region",
            context=str(self),
            explanation=(
                "Dynamo doesn't support compiling a region that returns "
                "a torch.autograd._unsafe_preserve_version_counter context manager."
            ),
            hints=[
                *graph_break_hints.SUPPORTABLE,
            ],
        )


class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable):
    """
    Tracks ``FSDPParamGroup.use_training_state(...)``.

    Entering sets the param group's ``_training_state`` to the target value;
    exiting sets it back to the value captured at creation. Each change is
    emitted as a traced ``__setattr__`` and also applied to the real object,
    so tracing observes the new state. A global FSDP training-state guard is
    installed on construction.
    """

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE)

    @staticmethod
    def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs):
        # Snapshot the current real state so exit can restore it.
        return FSDPParamGroupUseTrainingStateVariable(
            param_group_var=param_group_var,
            target_values=[target_value],
            initial_values=[param_group_var.value._training_state],
            **kwargs,
        )

    def __init__(
        self, param_group_var, target_values, initial_values=None, **kwargs
    ) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.param_group_var = param_group_var
        install_guard(self._guards_singleton)

    def module_name(self):
        return "torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup"

    def fn_name(self):
        return "use_training_state"

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ):
        # Revert the state that was applied eagerly before delegating.
        self._call_func(tx, self.initial_values)
        return super().call_function(tx, args, kwargs)

    def _call_func(self, tx: "InstructionTranslator", values):
        """Switch ``_training_state`` to ``values[0]`` unless it already is."""
        assert len(values) == 1
        desired = values[0]
        group = self.param_group_var
        if group.value._training_state == desired:
            return
        # Record the mutation in the trace, then mirror it on the real object.
        group.call_method(
            tx,
            "__setattr__",
            (
                variables.ConstantVariable.create("_training_state"),
                variables.EnumVariable(desired),
            ),
            {},
        )
        group.value._training_state = desired


class SDPAKernelVariable(ContextWrappingVariable):
    """
    Tracks ``torch.nn.attention.sdpa_kernel``.

    On entry the currently active backends are saved, the target backends are
    activated for real during tracing, and a matching ``_sdpa_kernel`` call is
    emitted into the FX graph. Exit emits a call that reinstates the saved
    backends.
    """

    @staticmethod
    def create(tx: "InstructionTranslator", backends, **kwargs):
        # Accept either one SDPBackend or a collection of them.
        target = (
            [backends]
            if isinstance(backends, torch.nn.attention.SDPBackend)
            else backends
        )
        return SDPAKernelVariable(
            target_values=target,
            initial_values=None,
            **kwargs,
        )

    def __init__(
        self,
        target_values: list[torch.nn.attention.SDPBackend],
        initial_values=None,
        **kwargs,
    ) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    @staticmethod
    def _backends_to_nodes(tx, backends):
        """
        Emit one ``_backend_from_string(name)`` node per backend.

        Round-tripping through the backend's name bakes it into the FX graph
        as a string constant.
        """
        from_string = torch.nn.attention._backend_from_string
        return [
            tx.output.create_node("call_function", from_string, (b.name,), {})
            for b in backends
        ]

    def enter(self, tx):
        attn = torch.nn.attention
        # Remember what was active, and make sure it is restored if tracing
        # bails out before exit runs.
        self.prev_backends = attn._cur_sdpa_kernel_backends()
        self.set_cleanup_hook(tx, lambda: attn._sdpa_kernel(self.prev_backends))
        attn._sdpa_kernel(self.target_values)
        target_nodes = self._backends_to_nodes(tx, self.target_values)
        tx.output.create_node(
            "call_function", attn._sdpa_kernel, (target_nodes,), {}
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup_assert()
        restore_nodes = self._backends_to_nodes(tx, self.prev_backends)
        tx.output.create_node(
            "call_function", torch.nn.attention._sdpa_kernel, (restore_nodes,), {}
        )
        return variables.ConstantVariable.create(None)

    def module_name(self):
        return "torch.nn.attention"

    def fn_name(self):
        # Reconstruction passes the entries of target_values one by one, so
        # it targets the private variadic variant of sdpa_kernel.
        return "_sdpa_kernel_variadic"


class StreamVariable(VariableTracker):
    """
    Tracks a device stream object during tracing.

    ``proxy`` is the FX proxy for the stream, or ``None`` when the stream has
    no graph value. ``value`` is the concrete stream object. ``device`` is its
    device, whose type must match ``value.device.type``.
    """

    def __init__(self, proxy, value, device, **kwargs) -> None:
        if proxy is not None:
            node_meta = proxy.node.meta
            if "example_value" in node_meta:
                # The recorded example value must compare equal to `value`.
                assert node_meta["example_value"] == value
        assert value.device.type == device.type, (
            "stream value is not equal to the passed device"
        )
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        self.device = device

    def python_type(self):
        return torch.Stream

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Trace a method call on the stream.

        Synchronization methods are recorded as graph nodes. ``query`` and
        ``record_event`` are recorded and wrapped into variables. Binary
        comparisons against another stream are folded using the concrete
        stream values. Anything else is delegated to the base class.
        """
        assert hasattr(self.value, name), f"no stream method found named {name}"

        from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        def _emit():
            # Record `stream.<name>(*args, **kwargs)` in the FX graph.
            return tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )

        if name in ("wait_stream", "synchronize", "wait_event"):
            _emit()
            return variables.ConstantVariable(None)

        result_cls = {
            "query": variables.ConstantVariable,
            "record_event": EventVariable,
        }.get(name)
        if result_cls is not None:
            return wrap_fx_proxy_cls(target_cls=result_cls, tx=tx, proxy=_emit())

        if name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
            # Comparisons are evaluated at trace time on the concrete stream
            # objects; a non-stream operand yields NotImplemented.
            other = args[0]
            if isinstance(other, StreamVariable):
                compare = cmp_name_to_op_mapping[name]
                return variables.ConstantVariable.create(
                    compare(self.value, other.value)
                )
            return variables.ConstantVariable.create(NotImplemented)

        return super().call_method(tx, name, args, kwargs)

    def as_proxy(self):
        return self.proxy

    def reconstruct(self, codegen):
        # Reaching here means the stream is fully owned by the graph: it is
        # neither an input nor a global, so it carries no source.
        assert not self.source
        # Unlike lists or dicts, a stream's identity must match the one used
        # in the graph. Routing it through a graph output is not designed yet,
        # so as a stopgap the stream is installed as a global and loaded back
        # from there in the generated bytecode.
        prefix = f"_stream_{self.device}"
        global_name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(codegen.create_load_global(global_name, add=True))


class EventVariable(VariableTracker):
    """
    Tracks a device event object during tracing.

    ``proxy`` is the FX proxy for the event, or ``None``. ``value`` is the
    concrete event object. Only ``wait``, ``record``, ``synchronize`` and
    ``query`` can be traced; any other method graph-breaks.
    """

    def __init__(self, proxy, value, **kwargs) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Record a supported event method call in the FX graph.

        Raises a graph break (via ``unimplemented_v2``) for unsupported
        methods.
        """
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait", "record", "synchronize"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            unimplemented_v2(
                gb_type="Unsupported torch.cuda.Event method",
                context=str(name),
                # Previously a trailing comma here turned the explanation
                # into a 1-tuple instead of a string.
                explanation=(
                    f"Dynamo doesn't support tracing the torch.cuda.Event.{name} method. "
                    "We currently support wait, record, synchronize, and query."
                ),
                hints=[
                    *graph_break_hints.SUPPORTABLE,
                ],
            )

    def as_proxy(self):
        return self.proxy

    def reconstruct(self, codegen):
        # Reaching here means the event is fully owned by the graph: it is
        # neither an input nor a global, so it carries no source.
        assert not self.source
        # As with streams, lift the event into a global and load it from
        # there in the generated bytecode.
        prefix = "_event"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(codegen.create_load_global(name, add=True))


class WithExitFunctionVariable(VariableTracker):
    """
    Stands in for the ``__exit__`` of a context manager that Dynamo tracks.

    Calling it during tracing forwards to ``ctx.exit``. ``target`` is only
    used by ``reconstruct``, where it is passed to ``create_setup_with``.
    Presumably it is the handler/jump target of the ``with`` block
    (TODO confirm against the caller that constructs this).
    """

    # `target` is listed alongside the base class's non-variable fields.
    # Presumably this marks attributes that are not VariableTrackers, so
    # VariableTracker traversal skips them. NOTE(review): confirm against
    # VariableTracker._nonvar_fields.
    _nonvar_fields = {
        "target",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self,
        ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable],
        target,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(
            ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
        )
        self.ctx = ctx
        self.target = target

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Positional args are forwarded unchanged to ctx.exit and its result
        # is returned. Keyword arguments are not supported.
        assert not kwargs
        return self.ctx.exit(tx, *args)

    def reconstruct(self, codegen):
        # Note here we reconstruct the context manager rather than the
        # exit function.  The handler generated by BlockStackEntry
        # will re-enter the context in the resume function.
        self.ctx.reconstruct_type(codegen)
        if codegen.tx.output.partial_convert:
            # Re-create the context manager by calling the reconstructed type
            # with its constant target_values, then set up the `with` again.
            if sys.version_info >= (3, 11):
                # Presumably arranges NULL + callable in the stack order that
                # CALL expects on this Python version (NULL below the callable
                # before 3.13) — confirm against `dis` docs.
                codegen.append_output(create_instruction("PUSH_NULL"))
                if sys.version_info < (3, 13):
                    codegen.append_output(create_instruction("SWAP", arg=2))
            codegen.extend_output(
                [codegen.create_load_const(val) for val in self.ctx.target_values]
            )
            codegen.extend_output(
                create_call_function(len(self.ctx.target_values), False)
            )
            codegen.append_output(create_setup_with(self.target))
            # Discard the value left on the stack after setup (presumably the
            # __enter__ result).
            codegen.append_output(create_instruction("POP_TOP"))
