"""
Python polyfills for common builtins.
"""

# NOTE: 1. Please do not import any submodule of this directory here, to avoid circular imports.
#       2. When adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py
#          and to the TYPE_CHECKING block below.

# mypy: allow-untyped-defs

import types
from collections.abc import MutableMapping, Sequence
from itertools import repeat as _repeat
from typing import Any, Callable, TYPE_CHECKING

import torch

from ..utils import dict_keys


if TYPE_CHECKING:
    # Loaded by torch._dynamo.polyfills.loader
    # See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py
    # Put the submodules here to avoid circular imports
    from . import (
        builtins as builtins,
        functools as functools,
        itertools as itertools,
        operator as operator,
        os as os,
        pytree as pytree,
        sys as sys,
    )

from torch.overrides import BaseTorchFunctionMode


# These classes handle support for TorchFunctionModes across graph breaks.
# Today, __enter__ for the TorchFunctionModes we support simply pushes the mode
# onto the mode stack. Since this mutates the stack and we replay such
# mutations, no cleanup logic needs to run once the graph break occurs:
# replaying the mutations ensures the torch function mode stack is correct at
# the graph break, and the stack is reconstructed normally when we compile the
# resume function on the other side of the break.
# However, to ensure we exit properly in the resume function, we need to
# re-enter the contexts as we do with other context managers. These contexts do
# nothing on enter, but provide the correct exit logic to ensure the stack
# state is correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    def __enter__(self):
        pass
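

# A rough sketch of the intended behavior (not how Dynamo wires this up
# internally): the mode is assumed to already be on the torch function mode
# stack (pushed by the replayed mutations), so __enter__ intentionally does not
# push it again, while the inherited __exit__ still pops it off:
#
#   mode = NoEnterTorchFunctionMode()   # already on the mode stack by assumption
#   with mode:
#       ...                             # __enter__ is a no-op
#   # the inherited __exit__ removes the mode from the stack as usual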


def index(iterator, item, start=0, end=None):
    from itertools import islice

    for i, elem in islice(enumerate(iterator), start, end):
        if item == elem:
            return i
    # This will not run in dynamo
    raise ValueError(f"{item} is not in {type(iterator)}")
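

# For illustration, assuming a plain Python call (outside of Dynamo tracing),
# this mirrors tuple.index/list.index with optional start/end bounds:
#   index(("a", "b", "c"), "b")        -> 1
#   index(("a", "b", "c"), "c", 1, 3)  -> 2
#   index(("a", "b"), "z")             -> raises ValueError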


def repeat(item, count):
    for _ in range(count):
        yield item


def radians(x):
    import math

    return math.pi / 180.0 * x


def accumulate_grad(x, new_grad):
    if new_grad is None:
        return
    new_grad = torch.clone(new_grad)
    if x.grad is None:
        x.grad = new_grad
    else:
        x.grad.add_(new_grad)
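

# Rough illustration of the accumulation semantics (assuming eager execution):
#   x = torch.zeros(2, requires_grad=True)
#   accumulate_grad(x, torch.ones(2))  # x.grad == tensor([1., 1.])
#   accumulate_grad(x, torch.ones(2))  # x.grad == tensor([2., 2.])
#   accumulate_grad(x, None)           # no-op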


# This mirrors
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/listobject.c#L3352-L3413
def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
    """emulate `(1,2,3) > (1,2)` etc"""
    # Apply `op` to the first pair that differ
    for a, b in zip(left, right):
        if a != b:
            return op(a, b)

    # No more pairs to compare, so compare sizes.
    return op(len(left), len(right))
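

# For example (plain Python, using the standard library's operator module):
#   import operator
#   list_cmp(operator.gt, (1, 2, 3), (1, 2))  -> True   # equal prefix, lengths compared
#   list_cmp(operator.lt, [1, 2], [1, 3])     -> True   # first differing pair decides
#   list_cmp(operator.eq, [1, 2], [1, 2, 3])  -> False  # equal prefix, unequal lengths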


def set_isdisjoint(set1, set2):
    for x in set1:
        if x in set2:
            return False
    return True


def set_intersection(set1, set2):
    intersection_set = set()
    for x in set1:
        if x in set2:
            intersection_set.add(x)
    return intersection_set


def set_union(set1, set2):
    union_set = set1.copy()
    set_update(union_set, set2)
    return union_set


def set_update(set1, set2):
    for x in set2:
        if x not in set1:
            set1.add(x)
    return set1


def set_difference(set1, set2):
    difference_set = set()
    for x in set1:
        if x not in set2:
            difference_set.add(x)
    return difference_set


def getattr_and_trace(*args, **kwargs):
    wrapper_obj = args[0]
    attr_name = args[1]
    fn = getattr(wrapper_obj, attr_name)
    return fn(*args[2:], **kwargs)


def mapping_get(obj, key, value=None):
    try:
        return obj.__getitem__(key)
    except KeyError:
        return value


def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
    obj = cls.__new__(cls, *args, **kwargs)

    # Only call __init__ if the object is an instance of the class
    # Reference: https://github.com/python/cpython/blob/3.12/Objects/typeobject.c#L1670-L1673
    if isinstance(obj, cls):
        obj.__init__(*args, **kwargs)
    return obj
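

# For example, with a hypothetical class whose __new__ substitutes a different
# type, __init__ is skipped, mirroring the CPython behavior referenced above:
#   class Weird:
#       def __new__(cls, *args, **kwargs):
#           return object()          # not an instance of Weird
#       def __init__(self, *args, **kwargs):
#           raise AssertionError("never called")
#   instantiate_user_defined_class_object(Weird)  # returns the plain object()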


# Used with something like dict(obj)
def construct_dict(cls, /, *args, **kwargs):
    dst = cls.__new__(cls)

    if args:
        src = args[0]

        # Ensure that the overridden __iter__ method is invoked
        if isinstance(src, (dict, MutableMapping)):
            for key in src:
                # This will inline the __getitem__ of the src object
                dst[key] = src[key]
        else:
            # likely a sequence of key/value pairs, e.g. a list of tuples
            for key, value in src:
                dst[key] = value

    if kwargs:
        for key in kwargs:
            dst[key] = kwargs[key]

    return dst
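

# For example (plain Python):
#   construct_dict(dict, {"a": 1}, b=2)             -> {"a": 1, "b": 2}
#   construct_dict(dict, [("a", 1), ("b", 2)], c=3) -> {"a": 1, "b": 2, "c": 3}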


def foreach_map_fn(*args):
    op = args[0]
    new_args: list[Any] = []
    at_least_one_list = False
    for arg in args[1:]:
        if not isinstance(arg, (list, tuple)):
            new_args.append(_repeat(arg))
        else:
            at_least_one_list = True
            new_args.append(arg)

    # Just apply op once to args if there are no lists
    if not at_least_one_list:
        return op(*args[1:])

    out = []
    for unpacked in zip(*new_args):
        out.append(op(*unpacked))

    return out
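

# For example (plain Python, non-tensor args for brevity):
#   foreach_map_fn(lambda a, b: a + b, [1, 2, 3], 10)    -> [11, 12, 13]
#   foreach_map_fn(lambda a, b: a + b, [1, 2], [10, 20]) -> [11, 22]
#   foreach_map_fn(lambda a, b: a + b, 1, 2)              -> 3  (no lists, op applied once)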


def foreach_lerp_inplace(self, end, weight):
    # Decompose foreach lerp into its constituent ops; this prevents a graph break
    # due to converting a value to a scalar when arg[2] is a single tensor
    result = torch._foreach_sub(end, self)
    result = torch._foreach_mul(result, weight)
    return torch._foreach_add_(self, result)
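

# foreach_lerp_inplace as decomposed above computes self[i] += weight * (end[i] - self[i])
# for each tensor in the list, i.e. the usual lerp update self + weight * (end - self).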


def foreach_pow_scalar(scalar, exps):
    return torch._foreach_pow([scalar for _ in exps], exps)


def addcmul_inplace(self, tensor1, tensor2, value):
    return self.add_(tensor1 * tensor2 * value)
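

# addcmul_inplace computes self += value * tensor1 * tensor2, the same result as
# self.addcmul_(tensor1, tensor2, value=value).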


def predicate(obj: Any) -> bool:
    # This will cause the rest of dynamo to handle the if statement correctly, so we don't have to rewrite it here.
    # We can't just use bool() here since we can't trace into that in general.
    if obj:
        return True
    return False


def cmp_eq(a, b):
    # Note: the `is` check below is intentionally commented out. It is a CPython
    # optimization that skips the __eq__ checks if the object ids are the same,
    # but these lines add many `is` nodes to the FX graph for SymNodeVariable,
    # so for now we just skip the check. This is STILL correct because one of
    # the __eq__ checks will pass later; it could just be slow in some corner
    # cases.
    # if a is b:
    #     return True
    result = a.__eq__(b)
    if result is NotImplemented:
        result = b.__eq__(a)
    return result is not NotImplemented and result
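

# For example (plain Python):
#   cmp_eq(1, 1.0)  -> True   # int.__eq__ handles the float
#   cmp_eq(1, "a")  -> False  # both __eq__ calls return NotImplemented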


def cmp_ne(a, b):
    # Check if __ne__ is overridden
    if isinstance(type(a).__ne__, types.FunctionType):
        return a.__ne__(b)
    return not cmp_eq(a, b)


def cmp_lt(a, b):
    result = a.__lt__(b)
    if result is NotImplemented:
        raise TypeError(f"{type(a)} does not support the < operator")
    return result


def cmp_le(a, b):
    # Check if __le__ is overridden
    if isinstance(type(a).__le__, types.FunctionType):
        return a.__le__(b)
    return cmp_eq(a, b) or cmp_lt(a, b)


def cmp_gt(a, b):
    # Check if __gt__ is overridden
    if isinstance(type(a).__gt__, types.FunctionType):
        return a.__gt__(b)
    # a > b is equivalent to b < a
    return cmp_lt(b, a)


def cmp_ge(a, b):
    # Check if __ge__ is overridden
    if isinstance(type(a).__ge__, types.FunctionType):
        return a.__ge__(b)
    return cmp_eq(a, b) or cmp_gt(a, b)
