from __future__ import annotations


"""Exception handling and error reporting for TorchDynamo.

This module provides a comprehensive set of exception classes and utilities for error
handling in TorchDynamo. It includes:

Base Exceptions:
    - TorchDynamoException: Base class for all TorchDynamo-specific exceptions
    - Various specialized subclasses for different error scenarios

User Error Handling:
    - UserError: Exceptions for user-facing errors in TorchDynamo usage
    - UserErrorType: Enumeration of different categories of user errors
    - Formatted error messages with debugging information

Observed Exceptions:
    - Classes for handling exceptions observed during tracing
    - Special handling for StopIteration, LookupError, etc.
    - Exception state management during compilation

Error Formatting:
    - Stack trace filtering and formatting
    - Error message augmentation
    - Debugging utilities for error reporting
"""

import logging
import os
import re
import textwrap
import typing
from enum import auto, Enum
from traceback import extract_stack, format_exc, format_list, StackSummary
from typing import Any, NoReturn, Optional, TYPE_CHECKING

import torch._guards

from . import config
from .utils import counters


if TYPE_CHECKING:
    import types

    from torch._guards import CompileId

    from .symbolic_convert import InstructionTranslatorBase
    from .types import DynamoFrameType


def exportdb_error_message(case_name: str) -> str:
    """Return a sentence pointing the user at the exportdb entry for ``case_name``.

    ``case_name`` is snake_case; the documentation anchor uses dashes instead
    of underscores.
    """
    anchor = case_name.replace("_", "-")
    base_url = "https://pytorch.org/docs/main/generated/exportdb/index.html#"
    return f"For more information about this error, see: {base_url}{anchor}"


# Module logger; the graph-break helpers below use it for user-visible
# warnings (log.warning).
log = logging.getLogger(__name__)
# Artifact logger registered under the "graph_breaks" name; the
# *_with_warning helpers send the verbose "WON'T CONVERT" report to it at
# debug level.
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")


class TorchDynamoException(RuntimeError):
    """Common base class for Dynamo's exceptions.

    Subclasses ``RuntimeError``. Most exceptions in this module derive from it;
    a few (``IncorrectUsage``, ``FailOnRecompileLimitHit``) derive directly from
    ``Exception`` instead.
    """

    pass


class InternalTorchDynamoError(TorchDynamoException):
    # NOTE(review): by its name this reports a failure inside Dynamo itself
    # rather than in user code; nothing in this module raises it -- confirm at
    # the raise sites.
    pass


class RestartAnalysis(TorchDynamoException):
    """Signal that the current Dynamo analysis should be started over.

    ``restart_reason`` optionally records why. Positional ``args`` become the
    exception message, as for any ``RuntimeError``. The code that catches this
    and performs the restart is not in this module.
    """

    restart_reason: Optional[str]

    def __init__(self, *args: Any, restart_reason: Optional[str] = None) -> None:
        super().__init__(*args)
        self.restart_reason = restart_reason


# The RestartAnalysis subclasses below add no behaviour of their own; they only
# let a handler tell apart *why* a restart was requested (presumably, going by
# the names: speculation, unspecialization, collective compilation, scalar
# tensorification -- confirm at the raise sites).
class SpeculationRestartAnalysis(RestartAnalysis):
    pass


class UnspecializeRestartAnalysis(RestartAnalysis):
    pass


class CompileCollectiveRestartAnalysis(RestartAnalysis):
    pass


class TensorifyScalarRestartAnalysis(RestartAnalysis):
    pass


class SkipFrame(TorchDynamoException):
    # NOTE(review): by its name, raised to make Dynamo give up on (skip) the
    # current frame; not raised in this module -- confirm against callers.
    pass


class TorchRuntimeError(TorchDynamoException):
    # NOTE(review): presumably wraps runtime errors from torch operations hit
    # while tracing; not raised in this module -- confirm against callers.
    pass


class InvalidBackend(TorchDynamoException):
    """Raised for an unrecognized backend name; the message points at list_backends()."""

    def __init__(self, name: str) -> None:
        message = f"Invalid backend: {name!r}, see `torch._dynamo.list_backends()` for available backends."
        super().__init__(message)


class ResetRequired(TorchDynamoException):
    """Raised when torch.compile() is called again with a different backend without a reset.

    The message tells the user to call ``torch._dynamo.reset()`` first.
    """

    def __init__(self) -> None:
        # Leading and trailing newline are intentional so the message stands on
        # its own lines when printed.
        lines = [
            "",
            "Must call `torch._dynamo.reset()` before changing backends.  Detected two calls to",
            "`torch.compile()` with a different backend compiler arguments.",
            "",
        ]
        super().__init__("\n".join(lines))


class ShortenTraceback(TorchDynamoException):
    """Exception whose traceback can be trimmed to start at ``first_useful_frame``.

    ``remove_dynamo_frames`` drops the traceback entries that precede
    ``first_useful_frame``, which hides Dynamo-internal frames from the user.
    The traceback is left untouched when no frame was given, when there is no
    traceback yet, or when ``config.verbose`` is set.
    """

    def __init__(
        self, *args: Any, first_useful_frame: Optional[types.FrameType], **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)
        self.first_useful_frame = first_useful_frame

    def remove_dynamo_frames(self) -> typing.Self:
        """Return ``self`` with its traceback starting at ``first_useful_frame``."""
        current = self.__traceback__
        target = self.first_useful_frame
        if target is None or current is None or config.verbose:
            return self
        # Walk forward until the entry for the target frame; it must be in the chain.
        while current.tb_frame is not target:
            current = current.tb_next
            assert current is not None, "internal error, please report a bug"
        return self.with_traceback(current)


class BackendCompilerFailed(ShortenTraceback):
    """Wraps an exception raised by a backend compiler function.

    ``backend_name`` is the backend callable's ``__name__`` (or "?" if it has
    none), and ``inner_exception`` is the original error. The message names
    both.
    """

    def __init__(
        self,
        backend_fn: Any,
        inner_exception: Exception,
        first_useful_frame: Optional[types.FrameType],
    ) -> None:
        self.backend_name = getattr(backend_fn, "__name__", "?")
        self.inner_exception = inner_exception
        inner_type_name = type(inner_exception).__name__
        super().__init__(
            f"backend={self.backend_name!r} raised:\n{inner_type_name}: {inner_exception}",
            first_useful_frame=first_useful_frame,
        )


class Unsupported(TorchDynamoException):
    """Raised (e.g. via ``unimplemented``/``unimplemented_v2``) to cause a graph break.

    Construction records ``real_stack`` (the current
    ``torch._guards.TracingContext`` stack). It also increments
    ``counters[category][msg]``, with category defaulting to "unimplemented".
    ``remove_from_stats`` undoes that increment. ``case_name`` optionally names
    an exportdb example.
    """

    def __init__(self, msg: str, *, case_name: Optional[str] = None) -> None:
        super().__init__(msg)
        self.msg = msg
        self.real_stack = torch._guards.TracingContext.extract_stack()
        self.category: Optional[str] = None
        self.add_to_stats()
        self.case_name: Optional[str] = case_name

    def remove_from_stats(self) -> None:
        """Undo the counter increment made by ``add_to_stats``, dropping empty entries."""
        assert self.category is not None
        per_message = counters[self.category]
        per_message[self.msg] -= 1
        if per_message[self.msg] <= 0:
            del per_message[self.msg]

    def add_to_stats(self, category: str = "unimplemented") -> None:
        """Count this exception's message under ``counters[category]``."""
        self.category = category
        counters[category][self.msg] += 1


class UnknownPropertiesDuringBackwardTrace(Unsupported):
    # An Unsupported variant (so constructing it also bumps the "unimplemented"
    # counters). Presumably raised when tracing a backward pass meets
    # properties Dynamo cannot determine -- confirm at the raise site.
    pass


class RecompileError(TorchDynamoException):
    # NOTE(review): not raised in this module; the name suggests a failure
    # related to recompilation -- confirm at the raise site.
    pass


class ArgsMismatchError(Unsupported):
    # An Unsupported whose constructor takes only ``msg`` (no ``case_name``).
    def __init__(self, msg: str) -> None:
        super().__init__(msg)


class AttributeMutationError(Unsupported):
    # An Unsupported whose constructor takes only ``msg`` (no ``case_name``).
    def __init__(self, msg: str) -> None:
        super().__init__(msg)


class InfiniteGeneratorError(Unsupported):
    # Raised when the number of yielded values is greater than MAX_ITERATOR_LIMIT
    # (that limit is defined outside this module).
    def __init__(self, msg: str) -> None:
        super().__init__(msg)


class SideEffectsError(Unsupported):
    # An Unsupported whose constructor takes only ``msg`` (no ``case_name``).
    def __init__(self, msg: str) -> None:
        super().__init__(msg)


class CondOpArgsMismatchError(ArgsMismatchError):
    """
    Internal error from cond() due to arguments mismatch.
    """

    def __init__(self, msg: str) -> None:
        super().__init__(msg)


class UserErrorType(Enum):
    """Category of a ``UserError``; stored on the exception as ``error_type``."""

    DYNAMIC_CONTROL_FLOW = auto()
    ANTI_PATTERN = auto()
    STANDARD_LIBRARY = auto()
    CONSTRAINT_VIOLATION = auto()
    DYNAMIC_DIM = auto()
    INVALID_INPUT = auto()
    INVALID_OUTPUT = auto()
    UNSUPPORTED_ALIASED_MUTATED_DYNAMIC_INPUTS = auto()


class UserError(Unsupported):
    def __init__(
        self, error_type: UserErrorType, msg: str, case_name: Optional[str] = None
    ) -> None:
        """
        An error for code that is valid in eager mode but not supported by TorchDynamo.

        The message should tell the user what to do next.

        error_type: Category of the user error.
        msg: Actionable error message.
        case_name: (Optional) snake_case name of the matching exportdb example;
            when given, a link to that example is appended to ``msg``.
        """
        if case_name is not None:
            assert isinstance(case_name, str)
            # Continue on the same line after a full stop, else start a new line.
            joiner = " " if msg.endswith(".") else "\n"
            msg = msg + joiner + exportdb_error_message(case_name)
        # NOTE(review): case_name is not forwarded to Unsupported, so
        # self.case_name ends up None for UserError -- confirm this is intended.
        super().__init__(msg)
        self.error_type = error_type
        self.message = msg


class SkipCodeRecursiveException(TorchDynamoException):
    # NOTE(review): by its name, signals that a code object (and presumably
    # what it calls) should be skipped; not raised in this module.
    pass


class RecompileLimitExceeded(Unsupported):
    # An Unsupported variant; by its name, raised once the recompile limit is hit.
    pass


class UnsafeScriptObjectError(TorchDynamoException):
    # NOTE(review): not raised in this module; the name suggests an unsafe use
    # of a script object -- confirm at the raise site.
    pass


class UncapturedHigherOrderOpError(TorchDynamoException):
    # NOTE(review): not raised in this module; the name suggests a higher-order
    # op whose body could not be captured -- confirm at the raise site.
    pass


class IncorrectUsage(Exception):
    # Derives from plain Exception, not TorchDynamoException, so
    # ``except TorchDynamoException`` handlers do not catch it.
    pass


# TODO: I'm a little uncertain about what error classification we should have
# for this.  This is potentially a user error, but regressions in
# specialization in PyTorch proper could also trigger this problem
class FailOnRecompileLimitHit(Exception):
    # Also derives from plain Exception rather than TorchDynamoException.
    pass


class ObservedException(TorchDynamoException):
    # An exception observed during the tracing. This exception is used by Dynamo to handle exceptions.
    # Base of the Observed* hierarchy; observed_exception_map and
    # get_dynamo_observed_exception map Python exception types onto subclasses.
    pass


class ObservedUserStopIteration(ObservedException):
    # A StopIteration raised by user code and observed while Dynamo traces it
    # (e.g. while tracing a user-defined ``__next__``).
    value: Optional[Any]

    # Mirrors ``StopIteration_init`` in CPython: the first positional argument,
    # if any, becomes ``value``; any other arguments are ignored.
    # https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__("unhandled `raise StopIteration`")
        self.value = args[0] if args else None


class ObservedLookupError(ObservedException):
    # A LookupError exception to be raised from inside Dynamo tracing. This can happen on __getitem__
    pass


class ObservedIndexError(ObservedLookupError):
    # An IndexError exception to be raised from inside Dynamo tracing. This can happen on list __getitem__
    pass


class ObservedKeyError(ObservedLookupError):
    # A KeyError exception to be raised from inside Dynamo tracing. This can happen on dict __getitem__
    pass


class ObservedGeneratorExit(ObservedException):
    # Observed counterpart of GeneratorExit (see observed_exception_map).
    pass


class ObservedAttributeError(ObservedException):
    # An AttributeError exception to be raised from inside Dynamo tracing. This can happen on user defined object __getattr__
    pass


class ObservedRuntimeError(ObservedException):
    # A RuntimeError exception to be raised from inside Dynamo tracing. This can happen on generator.throw(..) method
    pass


class ObservedNotImplementedError(ObservedException):
    # Observed counterpart of NotImplementedError (see observed_exception_map).
    pass


class ObservedTypeError(ObservedException):
    # A TypeError exception to be raised from inside Dynamo tracing. This can happen on generator.send(..) method
    pass


# Python exception type -> ObservedException subclass used when that exception
# is raised inside traced code (see raise_observed_exception). This dict is
# mutated at runtime: get_dynamo_observed_exception adds entries for types
# not listed here.
observed_exception_map = {
    StopIteration: ObservedUserStopIteration,
    LookupError: ObservedLookupError,
    IndexError: ObservedIndexError,
    GeneratorExit: ObservedGeneratorExit,
    KeyError: ObservedKeyError,
    AttributeError: ObservedAttributeError,
    RuntimeError: ObservedRuntimeError,
    NotImplementedError: ObservedNotImplementedError,
    TypeError: ObservedTypeError,
}


def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]:
    """Return the ObservedException subclass that stands in for ``exc_type``.

    Pre-registered types come from ``observed_exception_map``. For any other
    type, a new ``ObservedException`` subclass is created on first use and
    stored in that map, so later calls return the same class object.
    """
    if exc_type not in observed_exception_map:
        name = getattr(exc_type, "__name__", str(exc_type))
        # "Observed" + the exception's own name, matching the predefined
        # classes (ObservedKeyError, ObservedGeneratorExit, ...). Appending a
        # fixed "Error" suffix would give names like "ObservedValueErrorError".
        observed_exception_map[exc_type] = type(
            f"Observed{name}", (ObservedException,), {}
        )
    return observed_exception_map[exc_type]


def raise_observed_exception(
    exc_type: type[Exception],
    tx: InstructionTranslatorBase,
    *,
    args: Optional[list[Any]] = None,
    kwargs: Optional[dict[str, Any]] = None,
) -> NoReturn:
    """Raise ``exc_type`` inside the program being traced by ``tx``.

    Builds the exception object as a VariableTracker (``BuiltinVariable(exc_type)``
    called with ``args``/``kwargs``) and installs it as the current exception on
    ``tx.exn_vt_stack``. It then raises the matching ``ObservedException``
    subclass.

    Raises:
        ObservedException: always; the subclass comes from
            ``get_dynamo_observed_exception(exc_type)``.
    """
    from .variables import BuiltinVariable

    # CPython here raises an exception. Since there is no python code, we have to manually setup the exception
    # stack and raise the exception.
    exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {})  # type: ignore[arg-type]
    tx.exn_vt_stack.set_current_exception(exception_vt)
    # Use the lazy lookup: indexing observed_exception_map directly raised a
    # KeyError for types without a pre-registered Observed* class (e.g.
    # ValueError), after the exception VT had already been installed above.
    raise get_dynamo_observed_exception(exc_type)


def handle_observed_exception(tx: Any) -> None:
    """Discard the exception currently recorded on ``tx.exn_vt_stack``.

    This is for cases where CPython handles an exception entirely on its own
    (e.g. ``FOR_ITER`` consuming ``StopIteration``), so no Python-level
    handler would run and ``exception_handler`` is never reached. The
    behaviour being emulated is::

        try:
            ...  # somebody raising StopIteration
        except StopIteration:
            pass

    On 3.11 that ``except`` clause compiles to roughly ``PUSH_EXC_INFO``,
    ``LOAD_GLOBAL StopIteration``, ``CHECK_EXC_MATCH``,
    ``POP_JUMP_FORWARD_IF_FALSE``, ``POP_TOP``, ``POP_EXCEPT``. Its net effect
    on the traced exception state is a single pop, which is all this does.
    """
    exception_stack = tx.exn_vt_stack
    exception_stack.clear_current_exception()


# These exceptions are ok to fallback to eager/graph_break.
# All four are fake-tensor exceptions (torch._subclasses.fake_tensor); code
# outside this module decides what "fall back" means when one is caught.
exceptions_allowed_to_be_fallback = (
    torch._subclasses.fake_tensor.DataDependentOutputException,
    torch._subclasses.fake_tensor.DynamicOutputShapeException,
    torch._subclasses.fake_tensor.UnsupportedOperatorException,
    torch._subclasses.fake_tensor.UnsupportedFakeTensorException,
)


def _log_graph_break_reason(e: Exception, code: types.CodeType) -> None:
    # Shared by the *_with_warning helpers: emit the verbose "WON'T CONVERT"
    # report for ``e`` as a structured trace artifact and on graph_breaks_log.
    graph_break_msg = format_error_msg_verbose(e, code)
    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "dynamo_graph_break_reason",
            "encoding": "string",
        },
        payload_fn=lambda: graph_break_msg,
    )
    graph_breaks_log.debug("%s", graph_break_msg)


def unimplemented_with_warning(
    e: Exception, code: types.CodeType, msg: str
) -> NoReturn:
    """Graph break (via ``unimplemented``) and also warn the user.

    ``unimplemented`` on its own prints no user warning. Use this instead when
    an error in the torch.compile stack is worth surfacing. For example, if the
    AOT Autograd backend fails with a fake tensor exception, falling back to
    eager is fine, but it should not happen silently. The verbose report is
    logged, ``msg`` goes to ``log.warning``, and ``Unsupported`` is raised
    chained from ``e``.
    """
    _log_graph_break_reason(e, code)
    log.warning(msg)
    unimplemented(msg, from_exc=e)


# Sentinel distinguishing "no from_exc given" from an explicit ``None``
# (``raise ... from None`` suppresses exception context).
_NOTHING = object()


def unimplemented(
    msg: str, *, from_exc: Any = _NOTHING, case_name: Optional[str] = None
) -> NoReturn:
    """Cause a graph break by raising ``Unsupported(msg, case_name=case_name)``.

    If ``from_exc`` is given (even ``None``), it becomes the ``raise ... from`` cause.
    """
    # Debug hook: setting the BREAK env var to a message makes this assertion
    # fire when that exact message is about to be raised.
    assert msg != os.environ.get("BREAK", False)
    if from_exc is not _NOTHING:
        raise Unsupported(msg, case_name=case_name) from from_exc
    raise Unsupported(msg, case_name=case_name)


def unimplemented_v2_with_warning(
    e: Exception,
    code: types.CodeType,
    gb_type: str,
    context: str,
    explanation: str,
    hints: list[str],
) -> NoReturn:
    """Structured-message counterpart of ``unimplemented_with_warning``.

    Logs the verbose report for ``e``, then calls ``unimplemented_v2`` with
    ``log_warning=True``. That call emits the formatted graph-break message as
    a warning and raises ``Unsupported`` chained from ``e``.
    """
    _log_graph_break_reason(e, code)
    unimplemented_v2(gb_type, context, explanation, hints, from_exc=e, log_warning=True)


def format_graph_break_message(
    gb_type: str,
    context: str,
    explanation: str,
    hints: list[str],
) -> str:
    """Render a graph-break report.

    Layout: the type line, an "Explanation:" line, one "Hint:" line per hint
    (an empty line if there are none), a blank line, the developer context,
    and a trailing newline. Continuation lines of multi-line fields are
    indented by four spaces.
    """

    def hang(text: str) -> str:
        # Indent every line, then un-indent the first so it follows its label.
        return textwrap.indent(text, "    ").lstrip()

    hint_block = "\n".join(f"  Hint: {hang(hint)}" for hint in hints)
    sections = [
        f"{gb_type}",
        f"  Explanation: {hang(explanation)}",
        hint_block,
        "",
        f"  Developer debug context: {hang(context)}",
        "",
    ]
    return "\n".join(sections)


# TODO replace old unimplemented later
def unimplemented_v2(
    gb_type: str,
    context: str,
    explanation: str,
    hints: list[str],
    *,
    from_exc: Any = _NOTHING,
    log_warning: bool = False,
) -> NoReturn:
    """
    Called within dynamo to cause a graph break, by raising ``Unsupported``.

    Args:
        gb_type: Context-free graph break type. It should be a short string without any
                 information specific to the tracing context (i.e. no dynamically-generated strings)
        context: Developer context for the graph break. It can contain tracing context/dynamic strings.
        explanation: User-facing context-dependent explanation for the graph break. Can be dynamic.
        hints: List of user-facing hints for the graph break.
        from_exc: If given (even ``None``), used as the ``raise ... from`` cause.
        log_warning: If true, the formatted message is also emitted via ``log.warning``.
    """
    formatted = format_graph_break_message(gb_type, context, explanation, hints)
    if log_warning:
        log.warning(formatted)
    if from_exc is _NOTHING:
        raise Unsupported(formatted)
    raise Unsupported(formatted) from from_exc


def warning(msg: str) -> None:
    """Count ``msg`` under ``counters["warnings"]``.

    Debug hook: if the BREAK environment variable equals ``msg``, the
    assertion fires after counting.
    """
    warning_counts = counters["warnings"]
    warning_counts[msg] += 1
    break_on = os.environ.get("BREAK", False)
    assert msg != break_on


# KeyError has special handling for its args
# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
class KeyErrorMsg:
    """Wrapper whose repr() equals its str().

    ``str(KeyError(x))`` uses ``repr(x)`` when there is a single argument,
    which would put quotes around a plain message. Wrapping the message in this
    class keeps it unquoted.
    """

    def __init__(self, value: Any) -> None:
        self.value = value

    def __repr__(self) -> str:
        # Deliberately identical to __str__ (see class docstring).
        return self.__str__()

    def __str__(self) -> str:
        return str(self.value)


def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None:
    """Append user-facing debugging context to ``exc``'s first argument, in place.

    Sets ``exc.innermost_user_frame_summary`` to the last frame of the user
    stack (or ``None``). It then appends, when applicable:

    * the user-code stack;
    * the replay-record hint;
    * the TORCHDYNAMO_VERBOSE hint;
    * the minifier hint.

    The result is placed after ``msg`` and the existing first argument. For
    ``KeyError`` the new text is wrapped in ``KeyErrorMsg`` so it is not
    repr-quoted. ``export`` is currently unused.
    """
    pieces = [msg]
    exc.innermost_user_frame_summary = None  # type: ignore[attr-defined]

    user_stack = get_real_stack(exc)
    if user_stack:
        exc.innermost_user_frame_summary = user_stack[-1]  # type: ignore[attr-defined]
        pieces.append(f"\nfrom user code:\n {''.join(format_list(user_stack))}")

    if config.replay_record_enabled and hasattr(exc, "record_filename"):
        record = exc.record_filename
        pieces.append(
            f"\nLast frame execution written to {record}. To run only this frame "
            f"while debugging, run torch._dynamo.replay('{record}').\n"
        )

    if not config.verbose and hasattr(exc, "real_stack"):
        pieces.append(
            "\nSet TORCHDYNAMO_VERBOSE=1 for the internal stack trace "
            "(please do this especially if you're reporting a bug to PyTorch). "
            'For even more developer context, set TORCH_LOGS="+dynamo"\n'
        )

    inner = getattr(exc, "inner_exception", None)
    if inner is not None and hasattr(inner, "minifier_path"):
        script_path = inner.minifier_path
        if hasattr(inner, "buck_command"):
            pieces.append(
                f"\nMinifier script written to {script_path}. Run "
                f"this buck command to find the smallest traced graph "
                f"which reproduces this error: {inner.buck_command}\n"
            )
        else:
            pieces.append(
                f"\nMinifier script written to {script_path}. Run "
                "this script to find the smallest traced graph which reproduces this error.\n"
            )

    previous = str(exc.args[0]) if exc.args else ""
    combined = previous + "".join(pieces)
    head = KeyErrorMsg(combined) if isinstance(exc, KeyError) else combined
    exc.args = (head,) + exc.args[1:]


def get_exc_message(
    e: Exception, compile_id: CompileId
) -> tuple[Optional[str], Optional[int]]:
    """Return ``(filename, lineno)`` of the innermost user frame recorded on ``e``.

    The frame comes from ``e.innermost_user_frame_summary``, which
    ``augment_exc_message`` sets. If that attribute is missing or ``None``,
    ``(None, None)`` is returned instead of raising AttributeError. As a side
    effect, ``compile_id`` is always stored on ``e.compile_id``.
    """
    frame_summary = getattr(e, "innermost_user_frame_summary", None)
    e.compile_id = compile_id  # type: ignore[attr-defined]
    if frame_summary is None:
        return None, None
    return frame_summary.filename, frame_summary.lineno


def get_stack_above_dynamo() -> StackSummary:
    """Return the live call stack, cut at Dynamo's entry point by ``filter_stack``."""
    live_stack = extract_stack()
    return filter_stack(live_stack)


def get_real_stack(
    exc: Exception, frame: Optional[DynamoFrameType] = None
) -> Optional[StackSummary]:
    """Return the user-code stack recorded on ``exc``, or ``None`` if it has none.

    ``exc.real_stack`` (set e.g. by ``Unsupported.__init__``) is used as-is.
    When ``frame`` is given, the filtered stack of the current call site
    (``get_stack_above_dynamo()``) is prepended to it. The prefix is added even
    if ``real_stack`` is empty, since it can still help debugging.
    """
    captured = getattr(exc, "real_stack", None)
    if captured is None:
        return None

    if frame is None:
        prefix = StackSummary()
    else:
        # On Python 3.11+ ``frame`` is a _PyInterpreterFrame, not a real frame
        # object, so traceback cannot use it. Materializing it as
        # _PyFrame_GetFrameObject does would populate frame_obj, which the
        # default eval frame does not tolerate. Instead, extract the stack from
        # here and let filter_stack drop Dynamo's own frames. This is done on
        # every Python version so behavior is uniform and easy to test.
        prefix = get_stack_above_dynamo()

    return StackSummary.from_list(prefix + captured)


# filter out all frames after entering dynamo
def filter_stack(stack: StackSummary) -> StackSummary:
    """Keep only the frames a user cares about, up to where Dynamo took over.

    Scanning outermost-first:

    * frames without a filename are skipped;
    * scanning stops at the first ``convert_frame`` frame;
    * ``eval_frame`` frames and lines calling ``torch._dynamo.optimize(`` are dropped.
    """
    kept = StackSummary()
    for entry in stack:
        filename = entry.filename
        if filename is None:
            continue
        if "convert_frame" in filename:
            break
        if "eval_frame" in filename:
            continue
        source_line = entry.line
        if source_line and "torch._dynamo.optimize(" in source_line:
            continue
        kept.append(entry)
    return kept


def remove_resume_prefix(name: str) -> Optional[str]:
    """Return the original function name embedded in a resume-function name.

    Resume functions are named ``<TORCH_DYNAMO_RESUME_IN_PREFIX>_<fn>_at_<offset>``
    (e.g. ``torch_dynamo_resume_in_f_at_80``), which yields ``f``. Returns
    ``None`` when ``name`` does not start with that pattern.
    """
    from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX

    pattern = rf"{TORCH_DYNAMO_RESUME_IN_PREFIX}_(\w+)_at_\d+"
    found = re.match(pattern, name)
    return found.group(1) if found else None


def collapse_resume_frames(stack: StackSummary) -> StackSummary:
    """
    Fold chains of Dynamo resume-function frames back into the function they resume.

    After a graph break, Dynamo makes a regular Python call to a generated
    resume function, and that call is intercepted by Dynamo. The resume function
    therefore shows up as an extra frame in tracebacks. When a resume frame
    directly follows a frame of the function it resumes, it takes that frame's
    place and is renamed to the original function name. Frames without a
    filename are dropped. Note that the renamed frames are the same objects as
    in ``stack``, so the rename is visible to the caller.

    Example:
    File "..." line 3, in f
        <line 3>
    File "..." line 5, in torch_dynamo_resume_in_f_at_80
        <line 5>
    File "..." line 10, in torch_dynamo_resume_in_f_at_120
        <line 10>

    becomes
    File "..." line 10, in f
        <line 10>
    """
    collapsed = StackSummary()
    for entry in stack:
        if entry.filename is None:
            continue
        resumed_name = remove_resume_prefix(entry.name)
        if not (collapsed and resumed_name and collapsed[-1].name == resumed_name):
            collapsed.append(entry)
            continue
        # This resume frame continues the previous one: replace it.
        entry.name = resumed_name
        collapsed[-1] = entry
    return collapsed


def format_error_msg_verbose(
    exc: Exception,
    code: types.CodeType,
    record_filename: Optional[str] = None,
    frame: Optional[DynamoFrameType] = None,
) -> str:
    """Build the detailed "WON'T CONVERT" report for ``code``.

    The report contains the header, the currently handled exception's
    traceback (``format_exc()``), and, when ``exc`` carries a real stack, the
    user code being processed. ``record_filename`` is accepted but currently
    unused.
    """
    bar = "=" * 10
    pieces = [
        f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n",
        f"{bar} TorchDynamo Stack Trace {bar}\n",
        format_exc(),
    ]
    user_stack = get_real_stack(exc, frame)
    if user_stack is not None:
        pieces.append(
            f"\n{bar} The above exception occurred while processing the following code {bar}\n\n"
        )
        pieces.extend(format_list(user_stack))
        pieces.append(f"\n{bar}")
    return "".join(pieces)


def format_error_msg(
    exc: Exception,
    code: types.CodeType,
    record_filename: Optional[str] = None,
    frame: Optional[DynamoFrameType] = None,
) -> str:
    """Describe why ``code`` could not be converted.

    Delegates to ``format_error_msg_verbose`` when ``config.verbose`` is set.
    Otherwise it returns a short header followed by the currently handled
    exception's traceback.
    """
    if not config.verbose:
        return (
            f"WON'T CONVERT {code.co_name} {code.co_filename}"
            f" line {code.co_firstlineno} \ndue to: \n{format_exc()}"
        )
    return format_error_msg_verbose(exc, code, record_filename, frame)
